summary refs log tree commit diff stats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- dreambooth.py 13
1 files changed, 9 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 96213d0..3eecf9c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -130,6 +130,12 @@ def parse_args():
130 help="The embeddings directory where Textual Inversion embeddings are stored.", 130 help="The embeddings directory where Textual Inversion embeddings are stored.",
131 ) 131 )
132 parser.add_argument( 132 parser.add_argument(
133 "--mode",
134 type=str,
135 default=None,
136 help="A mode to filter the dataset.",
137 )
138 parser.add_argument(
133 "--seed", 139 "--seed",
134 type=int, 140 type=int,
135 default=None, 141 default=None,
@@ -284,7 +290,7 @@ def parse_args():
284 parser.add_argument( 290 parser.add_argument(
285 "--sample_frequency", 291 "--sample_frequency",
286 type=int, 292 type=int,
287 default=100, 293 default=1,
288 help="How often to save a checkpoint and sample image", 294 help="How often to save a checkpoint and sample image",
289 ) 295 )
290 parser.add_argument( 296 parser.add_argument(
@@ -759,6 +765,7 @@ def main():
759 num_class_images=args.num_class_images, 765 num_class_images=args.num_class_images,
760 size=args.resolution, 766 size=args.resolution,
761 repeats=args.repeats, 767 repeats=args.repeats,
768 mode=args.mode,
762 dropout=args.tag_dropout, 769 dropout=args.tag_dropout,
763 center_crop=args.center_crop, 770 center_crop=args.center_crop,
764 template_key=args.train_data_template, 771 template_key=args.train_data_template,
@@ -1046,7 +1053,7 @@ def main():
1046 unet.eval() 1053 unet.eval()
1047 text_encoder.eval() 1054 text_encoder.eval()
1048 1055
1049 with torch.autocast("cuda"), torch.inference_mode(): 1056 with torch.inference_mode():
1050 for step, batch in enumerate(val_dataloader): 1057 for step, batch in enumerate(val_dataloader):
1051 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 1058 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
1052 latents = latents * 0.18215 1059 latents = latents * 0.18215
@@ -1063,8 +1070,6 @@ def main():
1063 1070
1064 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 1071 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1065 1072
1066 model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
1067
1068 # Get the target for loss depending on the prediction type 1073 # Get the target for loss depending on the prediction type
1069 if noise_scheduler.config.prediction_type == "epsilon": 1074 if noise_scheduler.config.prediction_type == "epsilon":
1070 target = noise 1075 target = noise