diff options
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 11 |
1 file changed, 8 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index a849d2a..e281c73 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -112,6 +112,12 @@ def parse_args(): | |||
| 112 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 112 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 113 | ) | 113 | ) |
| 114 | parser.add_argument( | 114 | parser.add_argument( |
| 115 | "--mode", | ||
| 116 | type=str, | ||
| 117 | default=None, | ||
| 118 | help="A mode to filter the dataset.", | ||
| 119 | ) | ||
| 120 | parser.add_argument( | ||
| 115 | "--seed", | 121 | "--seed", |
| 116 | type=int, | 122 | type=int, |
| 117 | default=None, | 123 | default=None, |
| @@ -679,6 +685,7 @@ def main(): | |||
| 679 | num_class_images=args.num_class_images, | 685 | num_class_images=args.num_class_images, |
| 680 | size=args.resolution, | 686 | size=args.resolution, |
| 681 | repeats=args.repeats, | 687 | repeats=args.repeats, |
| 688 | mode=args.mode, | ||
| 682 | dropout=args.tag_dropout, | 689 | dropout=args.tag_dropout, |
| 683 | center_crop=args.center_crop, | 690 | center_crop=args.center_crop, |
| 684 | template_key=args.train_data_template, | 691 | template_key=args.train_data_template, |
| @@ -940,7 +947,7 @@ def main(): | |||
| 940 | text_encoder.eval() | 947 | text_encoder.eval() |
| 941 | val_loss = 0.0 | 948 | val_loss = 0.0 |
| 942 | 949 | ||
| 943 | with torch.autocast("cuda"), torch.inference_mode(): | 950 | with torch.inference_mode(): |
| 944 | for step, batch in enumerate(val_dataloader): | 951 | for step, batch in enumerate(val_dataloader): |
| 945 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 952 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| 946 | latents = latents * 0.18215 | 953 | latents = latents * 0.18215 |
| @@ -958,8 +965,6 @@ def main(): | |||
| 958 | 965 | ||
| 959 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 966 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 960 | 967 | ||
| 961 | model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) | ||
| 962 | |||
| 963 | # Get the target for loss depending on the prediction type | 968 | # Get the target for loss depending on the prediction type |
| 964 | if noise_scheduler.config.prediction_type == "epsilon": | 969 | if noise_scheduler.config.prediction_type == "epsilon": |
| 965 | target = noise | 970 | target = noise |
