From 8f4d212b3833041448678ad8a44a9a327934f74a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 15 Dec 2022 20:30:59 +0100 Subject: Avoid increased VRAM usage on validation --- dreambooth.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 96213d0..3eecf9c 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -129,6 +129,12 @@ def parse_args(): default=None, help="The embeddings directory where Textual Inversion embeddings are stored.", ) + parser.add_argument( + "--mode", + type=str, + default=None, + help="A mode to filter the dataset.", + ) parser.add_argument( "--seed", type=int, @@ -284,7 +290,7 @@ def parse_args(): parser.add_argument( "--sample_frequency", type=int, - default=100, + default=1, help="How often to save a checkpoint and sample image", ) parser.add_argument( @@ -759,6 +765,7 @@ def main(): num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, + mode=args.mode, dropout=args.tag_dropout, center_crop=args.center_crop, template_key=args.train_data_template, @@ -1046,7 +1053,7 @@ def main(): unet.eval() text_encoder.eval() - with torch.autocast("cuda"), torch.inference_mode(): + with torch.inference_mode(): for step, batch in enumerate(val_dataloader): latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -1063,8 +1070,6 @@ def main(): model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise -- cgit v1.2.3-54-g00ecf