From 8f4d212b3833041448678ad8a44a9a327934f74a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 15 Dec 2022 20:30:59 +0100 Subject: Avoid increased VRAM usage on validation --- data/csv.py | 13 ++++++++++--- dreambooth.py | 13 +++++++++---- infer.py | 4 ++-- textual_inversion.py | 11 ++++++++--- 4 files changed, 29 insertions(+), 12 deletions(-) diff --git a/data/csv.py b/data/csv.py index 20ac992..053457b 100644 --- a/data/csv.py +++ b/data/csv.py @@ -54,6 +54,7 @@ class CSVDataModule(pl.LightningDataModule): dropout: float = 0, interpolation: str = "bicubic", center_crop: bool = False, + mode: Optional[str] = None, template_key: str = "template", valid_set_size: Optional[int] = None, generator: Optional[torch.Generator] = None, @@ -80,6 +81,7 @@ class CSVDataModule(pl.LightningDataModule): self.repeats = repeats self.dropout = dropout self.center_crop = center_crop + self.mode = mode self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -99,7 +101,7 @@ class CSVDataModule(pl.LightningDataModule): self.data_root.joinpath(image.format(item["image"])), None, prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), - nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) + nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), ) for item in data ] @@ -118,7 +120,7 @@ class CSVDataModule(pl.LightningDataModule): item.instance_image_path, self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), item.prompt, - item.nprompt + item.nprompt, ) for item in items for i in range(image_multiplier) @@ -130,7 +132,12 @@ class CSVDataModule(pl.LightningDataModule): template = metadata[self.template_key] if self.template_key in metadata else {} items = metadata["items"] if "items" in metadata else [] - items = [item for item in items if not "skip" in item or item["skip"] != True] + if self.mode is not None: + items = [ + item + for item in items + if "mode" in item and self.mode in item["mode"] + ] items = self.prepare_items(template, items) items = self.filter_items(items) diff --git a/dreambooth.py b/dreambooth.py index 96213d0..3eecf9c 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -129,6 +129,12 @@ def parse_args(): default=None, help="The embeddings directory where Textual Inversion embeddings are stored.", ) + parser.add_argument( + "--mode", + type=str, + default=None, + help="A mode to filter the dataset.", + ) parser.add_argument( "--seed", type=int, @@ -284,7 +290,7 @@ def parse_args(): parser.add_argument( "--sample_frequency", type=int, - default=100, + default=1, help="How often to save a checkpoint and sample image", ) parser.add_argument( @@ -759,6 +765,7 @@ def main(): num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, + mode=args.mode, dropout=args.tag_dropout, center_crop=args.center_crop, template_key=args.train_data_template, @@ -1046,7 +1053,7 @@ def main(): unet.eval() text_encoder.eval() - with torch.autocast("cuda"), torch.inference_mode(): + with torch.inference_mode(): for step, batch in enumerate(val_dataloader): latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -1063,8 +1070,6 @@ def main(): model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise diff --git a/infer.py b/infer.py index efeb24d..420cb83 100644 --- a/infer.py +++ b/infer.py @@ -34,7 +34,7 @@ torch.backends.cudnn.benchmark = True default_args = { "model": "stabilityai/stable-diffusion-2-1", "precision": "fp32", - "ti_embeddings_dir": "embeddings_ti", + "ti_embeddings_dir": "embeddings", "output_dir": "output/inference", "config": None, } @@ -190,7 +190,7 @@ def create_pipeline(model, embeddings_dir, dtype): unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) - added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) + added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir)) print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") pipeline = VlpnStableDiffusion( diff --git a/textual_inversion.py b/textual_inversion.py index a849d2a..e281c73 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -111,6 +111,12 @@ def parse_args(): default=None, help="The embeddings directory where Textual Inversion embeddings are stored.", ) + parser.add_argument( + "--mode", + type=str, + default=None, + help="A mode to filter the dataset.", + ) parser.add_argument( "--seed", type=int, @@ -679,6 +685,7 @@ def main(): num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, + mode=args.mode, dropout=args.tag_dropout, center_crop=args.center_crop, template_key=args.train_data_template, @@ -940,7 +947,7 @@ def main(): text_encoder.eval() val_loss = 0.0 - with torch.autocast("cuda"), torch.inference_mode(): + with torch.inference_mode(): for step, batch in enumerate(val_dataloader): latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -958,8 +965,6 @@ def main(): model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - model_pred, noise = accelerator.gather_for_metrics((model_pred, noise)) - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise -- cgit v1.2.3-70-g09d2