diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py index 96213d0..3eecf9c 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -130,6 +130,12 @@ def parse_args(): | |||
130 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 130 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
131 | ) | 131 | ) |
132 | parser.add_argument( | 132 | parser.add_argument( |
133 | "--mode", | ||
134 | type=str, | ||
135 | default=None, | ||
136 | help="A mode to filter the dataset.", | ||
137 | ) | ||
138 | parser.add_argument( | ||
133 | "--seed", | 139 | "--seed", |
134 | type=int, | 140 | type=int, |
135 | default=None, | 141 | default=None, |
@@ -284,7 +290,7 @@ def parse_args(): | |||
284 | parser.add_argument( | 290 | parser.add_argument( |
285 | "--sample_frequency", | 291 | "--sample_frequency", |
286 | type=int, | 292 | type=int, |
287 | default=100, | 293 | default=1, |
288 | help="How often to save a checkpoint and sample image", | 294 | help="How often to save a checkpoint and sample image", |
289 | ) | 295 | ) |
290 | parser.add_argument( | 296 | parser.add_argument( |
@@ -759,6 +765,7 @@ def main(): | |||
759 | num_class_images=args.num_class_images, | 765 | num_class_images=args.num_class_images, |
760 | size=args.resolution, | 766 | size=args.resolution, |
761 | repeats=args.repeats, | 767 | repeats=args.repeats, |
768 | mode=args.mode, | ||
762 | dropout=args.tag_dropout, | 769 | dropout=args.tag_dropout, |
763 | center_crop=args.center_crop, | 770 | center_crop=args.center_crop, |
764 | template_key=args.train_data_template, | 771 | template_key=args.train_data_template, |
@@ -1046,7 +1053,7 @@ def main(): | |||
1046 | unet.eval() | 1053 | unet.eval() |
1047 | text_encoder.eval() | 1054 | text_encoder.eval() |
1048 | 1055 | ||
1049 | with torch.autocast("cuda"), torch.inference_mode(): | 1056 | with torch.inference_mode(): |
1050 | for step, batch in enumerate(val_dataloader): | 1057 | for step, batch in enumerate(val_dataloader): |
1051 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 1058 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
1052 | latents = latents * 0.18215 | 1059 | latents = latents * 0.18215 |
@@ -1063,8 +1070,6 @@ def main(): | |||
1063 | 1070 | ||
1064 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 1071 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
1065 | 1072 | ||
1066 | model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) | ||
1067 | |||
1068 | # Get the target for loss depending on the prediction type | 1073 | # Get the target for loss depending on the prediction type |
1069 | if noise_scheduler.config.prediction_type == "epsilon": | 1074 | if noise_scheduler.config.prediction_type == "epsilon": |
1070 | target = noise | 1075 | target = noise |