diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py index 1bff414..dd93e09 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -497,12 +497,12 @@ def main(): | |||
497 | pixel_values = [example["instance_images"] for example in examples] | 497 | pixel_values = [example["instance_images"] for example in examples] |
498 | 498 | ||
499 | # concat class and instance examples for prior preservation | 499 | # concat class and instance examples for prior preservation |
500 | if args.class_identifier and "class_prompt_ids" in examples[0]: | 500 | if args.class_identifier is not None and "class_prompt_ids" in examples[0]: |
501 | input_ids += [example["class_prompt_ids"] for example in examples] | 501 | input_ids += [example["class_prompt_ids"] for example in examples] |
502 | pixel_values += [example["class_images"] for example in examples] | 502 | pixel_values += [example["class_images"] for example in examples] |
503 | 503 | ||
504 | pixel_values = torch.stack(pixel_values) | 504 | pixel_values = torch.stack(pixel_values) |
505 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() | 505 | pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) |
506 | 506 | ||
507 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 507 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids |
508 | 508 | ||
@@ -529,7 +529,7 @@ def main(): | |||
529 | datamodule.prepare_data() | 529 | datamodule.prepare_data() |
530 | datamodule.setup() | 530 | datamodule.setup() |
531 | 531 | ||
532 | if args.class_identifier: | 532 | if args.class_identifier is not None: |
533 | missing_data = [item for item in datamodule.data if not item[1].exists()] | 533 | missing_data = [item for item in datamodule.data if not item[1].exists()] |
534 | 534 | ||
535 | if len(missing_data) != 0: | 535 | if len(missing_data) != 0: |
@@ -686,7 +686,7 @@ def main(): | |||
686 | # Predict the noise residual | 686 | # Predict the noise residual |
687 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 687 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
688 | 688 | ||
689 | if args.class_identifier: | 689 | if args.class_identifier is not None: |
690 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 690 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
691 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 691 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
692 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 692 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |