summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py10
1 files changed, 2 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 2fe89ec..1bff414 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -236,12 +236,6 @@ def parse_args():
236 help="The weight of prior preservation loss." 236 help="The weight of prior preservation loss."
237 ) 237 )
238 parser.add_argument( 238 parser.add_argument(
239 "--with_prior_preservation",
240 default=False,
241 action="store_true",
242 help="Flag to add prior perservation loss.",
243 )
244 parser.add_argument(
245 "--max_grad_norm", 239 "--max_grad_norm",
246 default=1.0, 240 default=1.0,
247 type=float, 241 type=float,
@@ -503,7 +497,7 @@ def main():
503 pixel_values = [example["instance_images"] for example in examples] 497 pixel_values = [example["instance_images"] for example in examples]
504 498
505 # concat class and instance examples for prior preservation 499 # concat class and instance examples for prior preservation
506 if args.with_prior_preservation and "class_prompt_ids" in examples[0]: 500 if args.class_identifier and "class_prompt_ids" in examples[0]:
507 input_ids += [example["class_prompt_ids"] for example in examples] 501 input_ids += [example["class_prompt_ids"] for example in examples]
508 pixel_values += [example["class_images"] for example in examples] 502 pixel_values += [example["class_images"] for example in examples]
509 503
@@ -692,7 +686,7 @@ def main():
692 # Predict the noise residual 686 # Predict the noise residual
693 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 687 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
694 688
695 if args.with_prior_preservation: 689 if args.class_identifier:
696 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 690 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
697 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 691 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
698 noise, noise_prior = torch.chunk(noise, 2, dim=0) 692 noise, noise_prior = torch.chunk(noise, 2, dim=0)