From a13bd300040f50c27548ad9cc5d9c9f4a3d4f503 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 22:07:45 +0200 Subject: Fix --- dreambooth.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index 2fe89ec..1bff414 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -235,12 +235,6 @@ def parse_args(): default=1.0, help="The weight of prior preservation loss." ) - parser.add_argument( - "--with_prior_preservation", - default=False, - action="store_true", - help="Flag to add prior perservation loss.", - ) parser.add_argument( "--max_grad_norm", default=1.0, @@ -503,7 +497,7 @@ def main(): pixel_values = [example["instance_images"] for example in examples] # concat class and instance examples for prior preservation - if args.with_prior_preservation and "class_prompt_ids" in examples[0]: + if args.class_identifier and "class_prompt_ids" in examples[0]: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] @@ -692,7 +686,7 @@ def main(): # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - if args.with_prior_preservation: + if args.class_identifier: # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) noise, noise_prior = torch.chunk(noise, 2, dim=0) -- cgit v1.2.3-70-g09d2