From 8f4d212b3833041448678ad8a44a9a327934f74a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 15 Dec 2022 20:30:59 +0100 Subject: Avoid increased VRAM usage on validation --- textual_inversion.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index a849d2a..e281c73 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -111,6 +111,12 @@ def parse_args(): default=None, help="The embeddings directory where Textual Inversion embeddings are stored.", ) + parser.add_argument( + "--mode", + type=str, + default=None, + help="A mode to filter the dataset.", + ) parser.add_argument( "--seed", type=int, @@ -679,6 +685,7 @@ def main(): num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, + mode=args.mode, dropout=args.tag_dropout, center_crop=args.center_crop, template_key=args.train_data_template, @@ -940,7 +947,7 @@ def main(): text_encoder.eval() val_loss = 0.0 - with torch.autocast("cuda"), torch.inference_mode(): + with torch.inference_mode(): for step, batch in enumerate(val_dataloader): latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -958,8 +965,6 @@ def main(): model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise -- cgit v1.2.3-54-g00ecf