summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index a849d2a..e281c73 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -112,6 +112,12 @@ def parse_args():
112 help="The embeddings directory where Textual Inversion embeddings are stored.", 112 help="The embeddings directory where Textual Inversion embeddings are stored.",
113 ) 113 )
114 parser.add_argument( 114 parser.add_argument(
115 "--mode",
116 type=str,
117 default=None,
118 help="A mode to filter the dataset.",
119 )
120 parser.add_argument(
115 "--seed", 121 "--seed",
116 type=int, 122 type=int,
117 default=None, 123 default=None,
@@ -679,6 +685,7 @@ def main():
679 num_class_images=args.num_class_images, 685 num_class_images=args.num_class_images,
680 size=args.resolution, 686 size=args.resolution,
681 repeats=args.repeats, 687 repeats=args.repeats,
688 mode=args.mode,
682 dropout=args.tag_dropout, 689 dropout=args.tag_dropout,
683 center_crop=args.center_crop, 690 center_crop=args.center_crop,
684 template_key=args.train_data_template, 691 template_key=args.train_data_template,
@@ -940,7 +947,7 @@ def main():
940 text_encoder.eval() 947 text_encoder.eval()
941 val_loss = 0.0 948 val_loss = 0.0
942 949
943 with torch.autocast("cuda"), torch.inference_mode(): 950 with torch.inference_mode():
944 for step, batch in enumerate(val_dataloader): 951 for step, batch in enumerate(val_dataloader):
945 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 952 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
946 latents = latents * 0.18215 953 latents = latents * 0.18215
@@ -958,8 +965,6 @@ def main():
958 965
959 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 966 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
960 967
961 model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
962
963 # Get the target for loss depending on the prediction type 968 # Get the target for loss depending on the prediction type
964 if noise_scheduler.config.prediction_type == "epsilon": 969 if noise_scheduler.config.prediction_type == "epsilon":
965 target = noise 970 target = noise