diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index a849d2a..e281c73 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py | |||
@@ -112,6 +112,12 @@ def parse_args(): | |||
112 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 112 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
113 | ) | 113 | ) |
114 | parser.add_argument( | 114 | parser.add_argument( |
115 | "--mode", | ||
116 | type=str, | ||
117 | default=None, | ||
118 | help="A mode to filter the dataset.", | ||
119 | ) | ||
120 | parser.add_argument( | ||
115 | "--seed", | 121 | "--seed", |
116 | type=int, | 122 | type=int, |
117 | default=None, | 123 | default=None, |
@@ -679,6 +685,7 @@ def main(): | |||
679 | num_class_images=args.num_class_images, | 685 | num_class_images=args.num_class_images, |
680 | size=args.resolution, | 686 | size=args.resolution, |
681 | repeats=args.repeats, | 687 | repeats=args.repeats, |
688 | mode=args.mode, | ||
682 | dropout=args.tag_dropout, | 689 | dropout=args.tag_dropout, |
683 | center_crop=args.center_crop, | 690 | center_crop=args.center_crop, |
684 | template_key=args.train_data_template, | 691 | template_key=args.train_data_template, |
@@ -940,7 +947,7 @@ def main(): | |||
940 | text_encoder.eval() | 947 | text_encoder.eval() |
941 | val_loss = 0.0 | 948 | val_loss = 0.0 |
942 | 949 | ||
943 | with torch.autocast("cuda"), torch.inference_mode(): | 950 | with torch.inference_mode(): |
944 | for step, batch in enumerate(val_dataloader): | 951 | for step, batch in enumerate(val_dataloader): |
945 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 952 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
946 | latents = latents * 0.18215 | 953 | latents = latents * 0.18215 |
@@ -958,8 +965,6 @@ def main(): | |||
958 | 965 | ||
959 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 966 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
960 | 967 | ||
961 | model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) | ||
962 | |||
963 | # Get the target for loss depending on the prediction type | 968 | # Get the target for loss depending on the prediction type |
964 | if noise_scheduler.config.prediction_type == "epsilon": | 969 | if noise_scheduler.config.prediction_type == "epsilon": |
965 | target = noise | 970 | target = noise |