From 89d471652644f449966a0cd944041c98dab7f66c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 07:25:24 +0100
Subject: Code deduplication

---
 data/csv.py                                       | 47 ++++++++++--
 models/clip/embeddings.py                         |  4 +-
 .../stable_diffusion/vlpn_stable_diffusion.py     | 32 +++-----
 train_dreambooth.py                               | 71 ++++--------------
 train_ti.py                                       | 86 +++++++---------------
 training/common.py                                | 55 ++++++++++++++
 6 files changed, 149 insertions(+), 146 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 9ad7dd6..f5fc8e6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,7 +1,7 @@
 import math
 import torch
 import json
-import copy
+from functools import partial
 from pathlib import Path
 from typing import NamedTuple, Optional, Union, Callable
 
@@ -99,6 +99,41 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
+def collate_fn(
+    num_class_images: int,
+    weight_dtype: torch.dtype,
+    prompt_processor: PromptProcessor,
+    examples
+):
+    prompt_ids = [example["prompt_ids"] for example in examples]
+    nprompt_ids = [example["nprompt_ids"] for example in examples]
+
+    input_ids = [example["instance_prompt_ids"] for example in examples]
+    pixel_values = [example["instance_images"] for example in examples]
+
+    # concat class and instance examples for prior preservation
+    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
+        input_ids += [example["class_prompt_ids"] for example in examples]
+        pixel_values += [example["class_images"] for example in examples]
+
+    pixel_values = torch.stack(pixel_values)
+    pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+
+    prompts = prompt_processor.unify_input_ids(prompt_ids)
+    nprompts = prompt_processor.unify_input_ids(nprompt_ids)
+    inputs = prompt_processor.unify_input_ids(input_ids)
+
+    batch = {
+        "prompt_ids": prompts.input_ids,
+        "nprompt_ids": nprompts.input_ids,
+        "input_ids": inputs.input_ids,
+        "pixel_values": pixel_values,
+        "attention_mask": inputs.attention_mask,
+    }
+
+    return batch
+
+
 class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -129,7 +164,7 @@ class VlpnDataModule():
         valid_set_repeat: int = 1,
         seed: Optional[int] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
-        collate_fn=None,
+        dtype: torch.dtype = torch.float32,
         num_workers: int = 0
     ):
         super().__init__()
@@ -158,9 +193,9 @@ class VlpnDataModule():
         self.valid_set_repeat = valid_set_repeat
         self.seed = seed
         self.filter = filter
-        self.collate_fn = collate_fn
         self.num_workers = num_workers
         self.batch_size = batch_size
+        self.dtype = dtype
 
     def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         image = template["image"] if "image" in template else "{}"
@@ -254,14 +289,16 @@ class VlpnDataModule():
             size=self.size,
             interpolation=self.interpolation,
         )
 
+        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor)
+
         self.train_dataloader = DataLoader(
             train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
        )
 
         self.val_dataloader = DataLoader(
             val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
         )
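[Note] The data/csv.py hunk above is the heart of the deduplication: the collate_fn closures that train_dreambooth.py and train_ti.py each defined are replaced by one module-level function whose configuration (num_class_images, weight_dtype, prompt_processor) is bound with functools.partial before the result is handed to the DataLoader. A minimal, self-contained sketch of that binding pattern, with a toy collate function standing in for the real one:

from functools import partial

import torch


def collate(scale: float, dtype: torch.dtype, examples: list) -> torch.Tensor:
    # Configuration parameters come first so they can be bound ahead of
    # time; only `examples` varies from batch to batch.
    return torch.stack(examples).to(dtype=dtype) * scale


# Bind the configuration once; the result has the one-argument
# (examples) -> batch signature that DataLoader expects of collate_fn.
collate_ = partial(collate, 2.0, torch.float32)

print(collate_([torch.ones(2), torch.ones(2)]))  # tensor([[2., 2.], [2., 2.]])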
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 46b414b..9a23a2a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -99,12 +99,12 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return embeds
 
-    def normalize(self, lambda_: float = 1.0):
+    def normalize(self, target: float = 0.4, lambda_: float = 1.0):
         w = self.temp_token_embedding.weight
         pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True)
         w[self.temp_token_ids] = F.normalize(
             w[self.temp_token_ids, :], dim=-1
-        ) * (pre_norm + lambda_ * (0.4 - pre_norm))
+        ) * (pre_norm + lambda_ * (target - pre_norm))
 
     def forward(
         self,
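[Note] normalize() previously hard-coded 0.4 as the norm it decays the temporary token embeddings toward; it is now a target parameter. Each call rescales every managed row so that its norm moves the fraction lambda_ of the way toward target. A standalone sketch of the same arithmetic:

import torch
import torch.nn.functional as F

target, lambda_ = 0.4, 0.1

w = torch.randn(4, 8) * 3.0             # stand-in for the managed embedding rows
pre_norm = w.norm(dim=-1, keepdim=True)

# Rescale each row so its norm moves 10% of the way toward the target;
# repeated calls converge geometrically.
w = F.normalize(w, dim=-1) * (pre_norm + lambda_ * (target - pre_norm))

print(w.norm(dim=-1))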
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb300d1..6bc40e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,7 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging
+from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from models.clip.prompt import PromptProcessor
@@ -250,8 +250,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps
 
-    def prepare_latents(self, batch_size, num_images_per_prompt, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -260,28 +260,16 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
 
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
-            latents = latents.to(device)
+            latents = latents.to(device=device, dtype=dtype)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
 
         return latents
 
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
         init_latent_dist = self.vae.encode(init_image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
@@ -292,7 +280,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+            init_latents = torch.cat([init_latents] * batch_size, dim=0)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -430,16 +418,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latents = self.prepare_latents_from_image(
                 image,
                 latent_timestep,
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 text_embeddings.dtype,
                 device,
                 generator
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
-                num_images_per_prompt,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
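[Note] Two pipeline simplifications: prepare_latents and prepare_latents_from_image now receive the already-expanded batch size (batch_size * num_images_per_prompt) instead of multiplying internally, the hard-coded // 8 gives way to self.vae_scale_factor, and the hand-rolled per-generator noise sampling is delegated to diffusers' randn_tensor. Roughly, that helper behaves like the following simplified sketch (the real implementation additionally special-cases MPS devices and tensor layouts):

from typing import List, Optional, Union

import torch


def randn_tensor_sketch(
    shape: tuple,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    if isinstance(generator, list):
        # One generator per batch element: draw each sample separately so
        # per-image seeds stay reproducible, then concatenate.
        samples = [
            torch.randn((1, *shape[1:]), generator=g, device=device, dtype=dtype)
            for g in generator
        ]
        return torch.cat(samples, dim=0)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)


noise = randn_tensor_sketch((2, 4, 8, 8), dtype=torch.float32)
print(noise.shape)  # torch.Size([2, 4, 8, 8])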
diff --git a/train_dreambooth.py b/train_dreambooth.py
index ebcf802..da3a075 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -14,7 +14,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
@@ -24,8 +23,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -750,35 +748,6 @@ def main():
         )
         return cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -798,7 +767,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )
 
     datamodule.prepare_data()
@@ -829,33 +798,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
-    if args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
+    if args.find_lr:
+        lr_scheduler = None
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
        )
 
     unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
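[Note] train_dreambooth.py now builds its scheduler through the shared training.common.get_scheduler, keeping only the find_lr special case at the call site, so the epoch-to-step conversion moves into the helper. A worked example of that bookkeeping, with illustrative values:

# Warmup is configured in epochs; the helper converts it to optimizer steps.
num_update_steps_per_epoch = 250        # e.g. derived from len(train_dataloader)
gradient_accumulation_steps = 4
warmup_epochs = 2
max_train_steps = 5000

warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
num_training_steps = max_train_steps * gradient_accumulation_steps

print(warmup_steps, num_training_steps)  # 2000 20000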
diff --git a/train_ti.py b/train_ti.py
index 9ec5cfb..3b7e3b1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel
@@ -22,8 +21,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -410,10 +408,16 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--max_grad_norm",
-        default=3.0,
+        "--decay_target",
+        default=0.4,
         type=float,
-        help="Max gradient norm."
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--decay_factor",
+        default=100,
+        type=float,
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -709,35 +713,6 @@ def main():
         )
         return cond1 and cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -757,7 +732,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )
 
     datamodule.setup()
@@ -786,35 +761,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
     if args.find_lr:
         lr_scheduler = None
-    elif args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
         )
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
@@ -868,7 +831,10 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+        text_encoder.text_model.embeddings.normalize(
+            args.decay_target,
+            min(1.0, args.decay_factor * lr)
+        )
 
     loop = partial(
         loss_step,
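[Note] train_ti.py swaps the --max_grad_norm argument for --decay_target and --decay_factor and passes them to the generalized normalize() after each optimizer step. The decay strength follows the current learning rate and is clamped at 1.0 so that a large rate cannot overshoot the target norm:

decay_factor = 100.0  # the --decay_factor default

for lr in (1e-2, 1e-3, 1e-4):
    lambda_ = min(1.0, decay_factor * lr)
    print(f"lr={lr:g} -> lambda_={lambda_:g}")

# lr=0.01   -> lambda_=1     (embeddings are pulled all the way to the target norm)
# lr=0.001  -> lambda_=0.1
# lr=0.0001 -> lambda_=0.01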
diff --git a/training/common.py b/training/common.py
index 0b2ae44..90cf910 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,10 +1,65 @@
+import math
+
 import torch
 import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
+from training.optimization import get_one_cycle_schedule
+
+
+def get_scheduler(
+    id: str,
+    min_lr: float,
+    lr: float,
+    warmup_func: str,
+    annealing_func: str,
+    warmup_exp: int,
+    annealing_exp: int,
+    cycles: int,
+    warmup_epochs: int,
+    optimizer: torch.optim.Optimizer,
+    max_train_steps: int,
+    num_update_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+):
+    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
+
+    if id == "one_cycle":
+        min_lr = 0.04 if min_lr is None else min_lr / lr
+
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        cycles = cycles if cycles is not None else math.ceil(
+            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=max_train_steps * gradient_accumulation_steps,
+        )
+
+    return lr_scheduler
+
 
 def generate_class_images(
     accelerator,
--
cgit v1.2.3-54-g00ecf
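[Note] For reference, a sketch of how the training scripts are expected to call the consolidated helper after this commit. The keyword names come from the training/common.py diff above; the values are purely illustrative, and the accepted warmup_func/annealing_func strings depend on get_one_cycle_schedule, which is not part of this patch:

import torch

from training.common import get_scheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_scheduler = get_scheduler(
    "one_cycle",                    # or "cosine_with_restarts", or any id known
    optimizer=optimizer,            # to diffusers.optimization.get_scheduler
    min_lr=None,                    # None -> the helper's 0.04 relative default
    lr=1e-4,
    warmup_func="cos",              # assumed value; see note above
    annealing_func="cos",
    warmup_exp=1,
    annealing_exp=1,
    cycles=None,
    warmup_epochs=2,
    max_train_steps=5000,
    num_update_steps_per_epoch=250,
    gradient_accumulation_steps=4,
)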