diff options
-rw-r--r-- | dreambooth.py | 10 |
1 files changed, 2 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py index 2fe89ec..1bff414 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -236,12 +236,6 @@ def parse_args(): | |||
236 | help="The weight of prior preservation loss." | 236 | help="The weight of prior preservation loss." |
237 | ) | 237 | ) |
238 | parser.add_argument( | 238 | parser.add_argument( |
239 | "--with_prior_preservation", | ||
240 | default=False, | ||
241 | action="store_true", | ||
242 | help="Flag to add prior perservation loss.", | ||
243 | ) | ||
244 | parser.add_argument( | ||
245 | "--max_grad_norm", | 239 | "--max_grad_norm", |
246 | default=1.0, | 240 | default=1.0, |
247 | type=float, | 241 | type=float, |
@@ -503,7 +497,7 @@ def main(): | |||
503 | pixel_values = [example["instance_images"] for example in examples] | 497 | pixel_values = [example["instance_images"] for example in examples] |
504 | 498 | ||
505 | # concat class and instance examples for prior preservation | 499 | # concat class and instance examples for prior preservation |
506 | if args.with_prior_preservation and "class_prompt_ids" in examples[0]: | 500 | if args.class_identifier and "class_prompt_ids" in examples[0]: |
507 | input_ids += [example["class_prompt_ids"] for example in examples] | 501 | input_ids += [example["class_prompt_ids"] for example in examples] |
508 | pixel_values += [example["class_images"] for example in examples] | 502 | pixel_values += [example["class_images"] for example in examples] |
509 | 503 | ||
@@ -692,7 +686,7 @@ def main(): | |||
692 | # Predict the noise residual | 686 | # Predict the noise residual |
693 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 687 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
694 | 688 | ||
695 | if args.with_prior_preservation: | 689 | if args.class_identifier: |
696 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 690 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
697 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 691 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
698 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | 692 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |