From 6aadb34af4fe5ca2dfc92fae8eee87610a5848ad Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 8 Oct 2022 21:56:54 +0200
Subject: Update

---
 textual_inversion.py | 57 ++++++++++++++++++++++++++-----------------------------
 1 file changed, 28 insertions(+), 29 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index 4f2de9e..09871d4 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -4,6 +4,7 @@ import math
 import os
 import datetime
 import logging
+import json
 from pathlib import Path

 import numpy as np
@@ -22,8 +23,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-import json
-import os

 from data.csv import CSVDataModule

@@ -70,7 +69,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=4,
+        default=200,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
@@ -141,7 +140,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="constant",
+        default="linear",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -402,20 +401,20 @@ class Checkpointer:
             generator=generator,
         )

-        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
+        with torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path.parent.mkdir(parents=True, exist_ok=True)

-            data_enum = enumerate(data)
+                data_enum = enumerate(data)

-            for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.placeholder_token)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                for i in range(self.sample_batches):
+                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                    prompt = [prompt.format(self.placeholder_token)
+                              for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                    nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

-                with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
@@ -429,15 +428,15 @@ class Checkpointer:
                         output_type='pil'
                     )["sample"]

-                all_samples += samples
+                    all_samples += samples

-                del samples
+                    del samples

-            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path)

-            del all_samples
-            del image_grid
+                del all_samples
+                del image_grid

         del unwrapped
         del scheduler
@@ -623,7 +622,7 @@ def main():
     datamodule.setup()

     if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data if not item[1].exists()]
+        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]

         if len(missing_data) != 0:
             batched_data = [missing_data[i:i+args.sample_batch_size]
@@ -643,20 +642,20 @@ def main():
             pipeline.enable_attention_slicing()
             pipeline.set_progress_bar_config(dynamic_ncols=True)

-            for batch in batched_data:
-                image_name = [p[1] for p in batch]
-                prompt = [p[2].format(args.initializer_token) for p in batch]
-                nprompt = [p[3] for p in batch]
+            with torch.inference_mode():
+                for batch in batched_data:
+                    image_name = [p.class_image_path for p in batch]
+                    prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                    nprompt = [p.nprompt for p in batch]

-                with accelerator.autocast():
                     images = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
                         num_inference_steps=args.sample_steps
                     ).images

-                for i, image in enumerate(images):
-                    image.save(image_name[i])
+                    for i, image in enumerate(images):
+                        image.save(image_name[i])

             del pipeline

--
cgit v1.2.3-54-g00ecf
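
Note on the pattern this commit applies: both sampling loops move under
torch.inference_mode(), which disables autograd tracking for the whole block
(stricter and cheaper than torch.no_grad()), replacing the autocast() context
that previously wrapped only the pipeline call. A minimal sketch of that
pattern, not taken from the patch itself; generate_images, pipeline, prompts,
and nprompts are hypothetical stand-ins for the objects in
textual_inversion.py:

    import torch

    def generate_images(pipeline, prompts, nprompts, steps=50):
        # inference_mode() turns off gradient bookkeeping for everything
        # created inside the block, so sampling allocates no autograd state.
        with torch.inference_mode():
            return pipeline(
                prompt=prompts,
                negative_prompt=nprompts,
                num_inference_steps=steps,
            ).images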