From 6aadb34af4fe5ca2dfc92fae8eee87610a5848ad Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 8 Oct 2022 21:56:54 +0200 Subject: Update --- data/csv.py | 162 +++++++++++---------- dreambooth.py | 75 ++++++---- environment.yaml | 2 +- infer.py | 12 +- .../stable_diffusion/vlpn_stable_diffusion.py | 5 +- textual_inversion.py | 57 ++++---- 6 files changed, 169 insertions(+), 144 deletions(-) diff --git a/data/csv.py b/data/csv.py index dcaf7d3..8637ac1 100644 --- a/data/csv.py +++ b/data/csv.py @@ -1,27 +1,38 @@ +import math import pandas as pd from pathlib import Path import pytorch_lightning as pl from PIL import Image from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms +from typing import NamedTuple, List + + +class CSVDataItem(NamedTuple): + instance_image_path: Path + class_image_path: Path + prompt: str + nprompt: str class CSVDataModule(pl.LightningDataModule): - def __init__(self, - batch_size, - data_file, - tokenizer, - instance_identifier, - class_identifier=None, - class_subdir="db_cls", - num_class_images=2, - size=512, - repeats=100, - interpolation="bicubic", - center_crop=False, - valid_set_size=None, - generator=None, - collate_fn=None): + def __init__( + self, + batch_size, + data_file, + tokenizer, + instance_identifier, + class_identifier=None, + class_subdir="db_cls", + num_class_images=100, + size=512, + repeats=100, + interpolation="bicubic", + center_crop=False, + valid_set_size=None, + generator=None, + collate_fn=None + ): super().__init__() self.data_file = Path(data_file) @@ -46,61 +57,50 @@ class CSVDataModule(pl.LightningDataModule): self.collate_fn = collate_fn self.batch_size = batch_size + def prepare_subdata(self, data, num_class_images=1): + image_multiplier = max(math.ceil(num_class_images / len(data)), 1) + + return [ + CSVDataItem( + self.data_root.joinpath(item.image), + self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), + item.prompt, + item.nprompt if "nprompt" in item else "" + ) + for item in data + if "skip" not in item or item.skip != "x" + for i in range(image_multiplier) + ] + def prepare_data(self): metadata = pd.read_csv(self.data_file) - instance_image_paths = [ - self.data_root.joinpath(f) - for f in metadata['image'].values - for i in range(self.num_class_images) - ] - class_image_paths = [ - self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}") - for f in metadata['image'].values - for i in range(self.num_class_images) - ] - prompts = [ - prompt - for prompt in metadata['prompt'].values - for i in range(self.num_class_images) - ] - nprompts = [ - nprompt - for nprompt in metadata['nprompt'].values - for i in range(self.num_class_images) - ] if 'nprompt' in metadata else [""] * len(instance_image_paths) - skips = [ - skip - for skip in metadata['skip'].values - for i in range(self.num_class_images) - ] if 'skip' in metadata else [""] * len(instance_image_paths) - self.data = [ - (i, c, p, n) - for i, c, p, n, s - in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) - if s != "x" - ] + metadata = list(metadata.itertuples()) + num_images = len(metadata) - def setup(self, stage=None): - valid_set_size = int(len(self.data) * 0.2) + valid_set_size = int(num_images * 0.2) if self.valid_set_size: valid_set_size = min(valid_set_size, self.valid_set_size) valid_set_size = max(valid_set_size, 1) - train_set_size = len(self.data) - valid_set_size + train_set_size = num_images - valid_set_size - self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) + data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator) - train_dataset = CSVDataset(self.data_train, self.tokenizer, + self.data_train = self.prepare_subdata(data_train, self.num_class_images) + self.data_val = self.prepare_subdata(data_val) + + def setup(self, stage=None): + train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, num_class_images=self.num_class_images, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats) - val_dataset = CSVDataset(self.data_val, self.tokenizer, + val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, instance_identifier=self.instance_identifier, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats) - self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, + self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True, collate_fn=self.collate_fn) - self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, + self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True, collate_fn=self.collate_fn) def train_dataloader(self): @@ -111,24 +111,28 @@ class CSVDataModule(pl.LightningDataModule): class CSVDataset(Dataset): - def __init__(self, - data, - tokenizer, - instance_identifier, - class_identifier=None, - num_class_images=2, - size=512, - repeats=1, - interpolation="bicubic", - center_crop=False, - ): + def __init__( + self, + data: List[CSVDataItem], + tokenizer, + instance_identifier, + batch_size=1, + class_identifier=None, + num_class_images=0, + size=512, + repeats=1, + interpolation="bicubic", + center_crop=False, + ): self.data = data self.tokenizer = tokenizer + self.batch_size = batch_size self.instance_identifier = instance_identifier self.class_identifier = class_identifier self.num_class_images = num_class_images self.cache = {} + self.image_cache = {} self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats @@ -149,46 +153,50 @@ class CSVDataset(Dataset): ) def __len__(self): - return self._length + return math.ceil(self._length / self.batch_size) * self.batch_size def get_example(self, i): - instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] - cache_key = f"{instance_image_path}_{class_image_path}" + item = self.data[i % self.num_instance_images] + cache_key = f"{item.instance_image_path}_{item.class_image_path}" if cache_key in self.cache: return self.cache[cache_key] example = {} - example["prompts"] = prompt - example["nprompts"] = nprompt + example["prompts"] = item.prompt + example["nprompts"] = item.nprompt - instance_image = Image.open(instance_image_path) - if not instance_image.mode == "RGB": - instance_image = instance_image.convert("RGB") + if item.instance_image_path in self.image_cache: + instance_image = self.image_cache[item.instance_image_path] + else: + instance_image = Image.open(item.instance_image_path) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + self.image_cache[item.instance_image_path] = instance_image example["instance_images"] = instance_image example["instance_prompt_ids"] = self.tokenizer( - prompt.format(self.instance_identifier), + item.prompt.format(self.instance_identifier), padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids if self.num_class_images != 0: - class_image = Image.open(class_image_path) + class_image = Image.open(item.class_image_path) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = class_image example["class_prompt_ids"] = self.tokenizer( - prompt.format(self.class_identifier), + item.prompt.format(self.class_identifier), padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids - self.cache[instance_image_path] = example + self.cache[item.instance_image_path] = example return example def __getitem__(self, i): diff --git a/dreambooth.py b/dreambooth.py index a26bea7..7b61c45 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -3,6 +3,7 @@ import math import os import datetime import logging +import json from pathlib import Path import numpy as np @@ -21,7 +22,6 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -import json from data.csv import CSVDataModule @@ -68,7 +68,7 @@ def parse_args(): parser.add_argument( "--num_class_images", type=int, - default=4, + default=200, help="How many class images to generate per training image." ) parser.add_argument( @@ -140,7 +140,7 @@ def parse_args(): parser.add_argument( "--lr_scheduler", type=str, - default="constant", + default="linear", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' @@ -198,6 +198,12 @@ def parse_args(): default=-1, help="For distributed training: local_rank" ) + parser.add_argument( + "--sample_frequency", + type=int, + default=100, + help="How often to save a checkpoint and sample image", + ) parser.add_argument( "--sample_image_size", type=int, @@ -366,20 +372,20 @@ class Checkpointer: generator=generator, ) - for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: - all_samples = [] - file_path = samples_path.joinpath(pool, f"step_{step}.png") - file_path.parent.mkdir(parents=True, exist_ok=True) + with torch.inference_mode(): + for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: + all_samples = [] + file_path = samples_path.joinpath(pool, f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) - data_enum = enumerate(data) + data_enum = enumerate(data) - for i in range(self.sample_batches): - batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] - prompt = [prompt.format(self.instance_identifier) - for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] - nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] + for i in range(self.sample_batches): + batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] + prompt = [prompt.format(self.instance_identifier) + for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] + nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] - with self.accelerator.autocast(): samples = pipeline( prompt=prompt, negative_prompt=nprompt, @@ -393,15 +399,15 @@ class Checkpointer: output_type='pil' )["sample"] - all_samples += samples + all_samples += samples - del samples + del samples - image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) - image_grid.save(file_path) + image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) + image_grid.save(file_path) - del all_samples - del image_grid + del all_samples + del image_grid del unwrapped del scheduler @@ -538,7 +544,7 @@ def main(): datamodule.setup() if args.num_class_images != 0: - missing_data = [item for item in datamodule.data if not item[1].exists()] + missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] if len(missing_data) != 0: batched_data = [missing_data[i:i+args.sample_batch_size] @@ -558,20 +564,20 @@ def main(): pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) - for batch in batched_data: - image_name = [p[1] for p in batch] - prompt = [p[2].format(args.class_identifier) for p in batch] - nprompt = [p[3] for p in batch] + with torch.inference_mode(): + for batch in batched_data: + image_name = [p.class_image_path for p in batch] + prompt = [p.prompt.format(args.class_identifier) for p in batch] + nprompt = [p.nprompt for p in batch] - with accelerator.autocast(): images = pipeline( prompt=prompt, negative_prompt=nprompt, num_inference_steps=args.sample_steps ).images - for i, image in enumerate(images): - image.save(image_name[i]) + for i, image in enumerate(images): + image.save(image_name[i]) del pipeline @@ -677,6 +683,8 @@ def main(): unet.train() train_loss = 0.0 + sample_checkpoint = False + for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space @@ -737,6 +745,9 @@ def main(): global_step += 1 + if global_step % args.sample_frequency == 0: + sample_checkpoint = True + logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} local_progress_bar.set_postfix(**logs) @@ -783,7 +794,11 @@ def main(): val_loss /= len(val_dataloader) - accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) + accelerator.log({ + "train/loss": train_loss, + "val/loss": val_loss, + "lr": lr_scheduler.get_last_lr()[0] + }, step=global_step) local_progress_bar.clear() global_progress_bar.clear() @@ -792,7 +807,7 @@ def main(): accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") min_val_loss = val_loss - if accelerator.is_main_process: + if sample_checkpoint and accelerator.is_main_process: checkpointer.save_samples( global_step, args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) diff --git a/environment.yaml b/environment.yaml index c9f498e..5ecc5a8 100644 --- a/environment.yaml +++ b/environment.yaml @@ -32,6 +32,6 @@ dependencies: - test-tube>=0.7.5 - torch-fidelity==0.3.0 - torchmetrics==0.9.3 - - transformers==4.22.1 + - transformers==4.22.2 - triton==2.0.0.dev20220924 - xformers==0.0.13 diff --git a/infer.py b/infer.py index 6197aa3..a542534 100644 --- a/infer.py +++ b/infer.py @@ -5,12 +5,11 @@ import sys import shlex import cmd from pathlib import Path -from torch import autocast import torch import json from PIL import Image -from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler -from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor +from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler +from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from schedulers.scheduling_euler_a import EulerAScheduler @@ -22,7 +21,7 @@ torch.backends.cuda.matmul.allow_tf32 = True default_args = { "model": None, "scheduler": "euler_a", - "precision": "bf16", + "precision": "fp16", "embeddings_dir": "embeddings", "output_dir": "output/inference", "config": None, @@ -260,7 +259,7 @@ def generate(output_dir, pipeline, args): else: init_image = None - with autocast("cuda"): + with torch.autocast("cuda"), torch.inference_mode(): for i in range(args.batch_num): pipeline.set_progress_bar_config( desc=f"Batch {i + 1} of {args.batch_num}", @@ -313,6 +312,9 @@ class CmdParse(cmd.Cmd): args = run_parser(self.parser, default_cmds, elements) except SystemExit: self.parser.print_help() + except Exception as e: + print(e) + return if len(args.prompt) == 0: print('Try again with a prompt!') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index a198cf6..bfecd1c 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -234,7 +234,8 @@ class VlpnStableDiffusion(DiffusionPipeline): latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) elif isinstance(latents, PIL.Image.Image): latents = preprocess(latents, width, height) - latent_dist = self.vae.encode(latents.to(self.device)).latent_dist + latents = latents.to(device=self.device, dtype=latents_dtype) + latent_dist = self.vae.encode(latents).latent_dist latents = latent_dist.sample(generator=generator) latents = 0.18215 * latents @@ -249,7 +250,7 @@ class VlpnStableDiffusion(DiffusionPipeline): timesteps = torch.tensor([timesteps] * batch_size, device=self.device) # add noise to latents using the timesteps - noise = torch.randn(latents.shape, generator=generator, device=self.device) + noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) latents = self.scheduler.add_noise(latents, noise, timesteps) else: if latents.shape != latents_shape: diff --git a/textual_inversion.py b/textual_inversion.py index 4f2de9e..09871d4 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -4,6 +4,7 @@ import math import os import datetime import logging +import json from pathlib import Path import numpy as np @@ -22,8 +23,6 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -import json -import os from data.csv import CSVDataModule @@ -70,7 +69,7 @@ def parse_args(): parser.add_argument( "--num_class_images", type=int, - default=4, + default=200, help="How many class images to generate per training image." ) parser.add_argument( @@ -141,7 +140,7 @@ def parse_args(): parser.add_argument( "--lr_scheduler", type=str, - default="constant", + default="linear", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' @@ -402,20 +401,20 @@ class Checkpointer: generator=generator, ) - for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: - all_samples = [] - file_path = samples_path.joinpath(pool, f"step_{step}.png") - file_path.parent.mkdir(parents=True, exist_ok=True) + with torch.inference_mode(): + for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: + all_samples = [] + file_path = samples_path.joinpath(pool, f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) - data_enum = enumerate(data) + data_enum = enumerate(data) - for i in range(self.sample_batches): - batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] - prompt = [prompt.format(self.placeholder_token) - for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] - nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] + for i in range(self.sample_batches): + batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] + prompt = [prompt.format(self.placeholder_token) + for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] + nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] - with self.accelerator.autocast(): samples = pipeline( prompt=prompt, negative_prompt=nprompt, @@ -429,15 +428,15 @@ class Checkpointer: output_type='pil' )["sample"] - all_samples += samples + all_samples += samples - del samples + del samples - image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) - image_grid.save(file_path) + image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) + image_grid.save(file_path) - del all_samples - del image_grid + del all_samples + del image_grid del unwrapped del scheduler @@ -623,7 +622,7 @@ def main(): datamodule.setup() if args.num_class_images != 0: - missing_data = [item for item in datamodule.data if not item[1].exists()] + missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] if len(missing_data) != 0: batched_data = [missing_data[i:i+args.sample_batch_size] @@ -643,20 +642,20 @@ def main(): pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) - for batch in batched_data: - image_name = [p[1] for p in batch] - prompt = [p[2].format(args.initializer_token) for p in batch] - nprompt = [p[3] for p in batch] + with torch.inference_mode(): + for batch in batched_data: + image_name = [p.class_image_path for p in batch] + prompt = [p.prompt.format(args.initializer_token) for p in batch] + nprompt = [p.nprompt for p in batch] - with accelerator.autocast(): images = pipeline( prompt=prompt, negative_prompt=nprompt, num_inference_steps=args.sample_steps ).images - for i, image in enumerate(images): - image.save(image_name[i]) + for i, image in enumerate(images): + image.save(image_name[i]) del pipeline -- cgit v1.2.3-70-g09d2