Diffstat (limited to 'dreambooth.py')

 dreambooth.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 96213d0..3eecf9c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -130,6 +130,12 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
+        "--mode",
+        type=str,
+        default=None,
+        help="A mode to filter the dataset.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -284,7 +290,7 @@ def parse_args():
     parser.add_argument(
         "--sample_frequency",
         type=int,
-        default=100,
+        default=1,
         help="How often to save a checkpoint and sample image"
     )
     parser.add_argument(
@@ -759,6 +765,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        mode=args.mode,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
         template_key=args.train_data_template,
@@ -1046,7 +1053,7 @@ def main():
         unet.eval()
         text_encoder.eval()

-        with torch.autocast("cuda"), torch.inference_mode():
+        with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
@@ -1063,8 +1070,6 @@ def main():

                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-                model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
-
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
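A note on the autocast change, since it is easy to misread: after this commit the validation pass runs under torch.inference_mode() alone, without torch.autocast("cuda"). The two context managers do different things: inference_mode() disables autograd tracking, while autocast() controls mixed-precision casting, so validation now runs in the model's native dtype. A minimal standalone sketch of the distinction (a toy module, not code from dreambooth.py):

import torch

# Toy example, not part of dreambooth.py: show what inference_mode()
# does and does not do on its own.
linear = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)

with torch.inference_mode():
    y = linear(x)

print(y.requires_grad)  # False: no autograd graph is recorded
print(y.dtype)          # torch.float32: dtype is unchanged without autocast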
