summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-04 09:19:52 +0200
committerVolpeon <git@volpeon.ink>2022-10-04 09:19:52 +0200
commit0462f5d20fdc82806753629b6f5c5fb39a88d1d2 (patch)
treefb9f73860a597606539471a27f583c4a8337961d /dreambooth.py
parentFix (diff)
downloadtextual-inversion-diff-0462f5d20fdc82806753629b6f5c5fb39a88d1d2.tar.gz
textual-inversion-diff-0462f5d20fdc82806753629b6f5c5fb39a88d1d2.tar.bz2
textual-inversion-diff-0462f5d20fdc82806753629b6f5c5fb39a88d1d2.zip
Bugfix
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 1bff414..dd93e09 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -497,12 +497,12 @@ def main():
497 pixel_values = [example["instance_images"] for example in examples] 497 pixel_values = [example["instance_images"] for example in examples]
498 498
499 # concat class and instance examples for prior preservation 499 # concat class and instance examples for prior preservation
500 if args.class_identifier and "class_prompt_ids" in examples[0]: 500 if args.class_identifier is not None and "class_prompt_ids" in examples[0]:
501 input_ids += [example["class_prompt_ids"] for example in examples] 501 input_ids += [example["class_prompt_ids"] for example in examples]
502 pixel_values += [example["class_images"] for example in examples] 502 pixel_values += [example["class_images"] for example in examples]
503 503
504 pixel_values = torch.stack(pixel_values) 504 pixel_values = torch.stack(pixel_values)
505 pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 505 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
506 506
507 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 507 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
508 508
@@ -529,7 +529,7 @@ def main():
529 datamodule.prepare_data() 529 datamodule.prepare_data()
530 datamodule.setup() 530 datamodule.setup()
531 531
532 if args.class_identifier: 532 if args.class_identifier is not None:
533 missing_data = [item for item in datamodule.data if not item[1].exists()] 533 missing_data = [item for item in datamodule.data if not item[1].exists()]
534 534
535 if len(missing_data) != 0: 535 if len(missing_data) != 0:
@@ -686,7 +686,7 @@ def main():
686 # Predict the noise residual 686 # Predict the noise residual
687 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 687 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
688 688
689 if args.class_identifier: 689 if args.class_identifier is not None:
690 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 690 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
691 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 691 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
692 noise, noise_prior = torch.chunk(noise, 2, dim=0) 692 noise, noise_prior = torch.chunk(noise, 2, dim=0)