From 7ccd4614a56cfd6ecacba85605f338593f1059f0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 7 Feb 2023 20:44:43 +0100 Subject: Add Lora --- environment.yaml | 4 +- infer.py | 21 +- .../stable_diffusion/vlpn_stable_diffusion.py | 14 +- train_dreambooth.py | 46 +- train_lora.py | 566 +++++++++++++++++++++ train_ti.py | 10 +- training/functional.py | 31 +- training/strategy/dreambooth.py | 35 +- training/strategy/lora.py | 147 ++++++ training/strategy/ti.py | 38 +- 10 files changed, 819 insertions(+), 93 deletions(-) create mode 100644 train_lora.py create mode 100644 training/strategy/lora.py diff --git a/environment.yaml b/environment.yaml index c992759..f5632bf 100644 --- a/environment.yaml +++ b/environment.yaml @@ -18,11 +18,11 @@ dependencies: - -e git+https://github.com/huggingface/diffusers#egg=diffusers - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion - accelerate==0.15.0 - - bitsandbytes==0.36.0.post2 + - bitsandbytes==0.37.0 - python-slugify>=6.1.2 - safetensors==0.2.7 - setuptools==65.6.3 - test-tube>=0.7.5 - transformers==4.25.1 - triton==2.0.0.dev20221202 - - xformers==0.0.16.dev430 + - xformers==0.0.17.dev443 diff --git a/infer.py b/infer.py index 2b07b21..42b4e2d 100644 --- a/infer.py +++ b/infer.py @@ -39,7 +39,8 @@ torch.backends.cudnn.benchmark = True default_args = { "model": "stabilityai/stable-diffusion-2-1", "precision": "fp32", - "ti_embeddings_dir": "embeddings", + "ti_embeddings_dir": "embeddings_ti", + "lora_embeddings_dir": "embeddings_lora", "output_dir": "output/inference", "config": None, } @@ -60,6 +61,7 @@ default_cmds = { "batch_num": 1, "steps": 30, "guidance_scale": 7.0, + "lora_scale": 0.5, "seed": None, "config": None, } @@ -91,6 +93,10 @@ def create_args_parser(): "--ti_embeddings_dir", type=str, ) + parser.add_argument( + "--lora_embeddings_dir", + type=str, + ) parser.add_argument( "--output_dir", type=str, @@ -168,6 +174,10 @@ def create_cmd_parser(): "--guidance_scale", type=float, ) + parser.add_argument( + "--lora_scale", + type=float, + ) parser.add_argument( "--seed", type=int, @@ -315,6 +325,7 @@ def generate(output_dir: Path, pipeline, args): generator=generator, image=init_image, strength=args.image_noise, + cross_attention_kwargs={"scale": args.lora_scale}, ).images for j, image in enumerate(images): @@ -334,11 +345,12 @@ class CmdParse(cmd.Cmd): prompt = 'dream> ' commands = [] - def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser): + def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): super().__init__() self.output_dir = output_dir self.ti_embeddings_dir = ti_embeddings_dir + self.lora_embeddings_dir = lora_embeddings_dir self.pipeline = pipeline self.parser = parser @@ -394,9 +406,12 @@ def main(): dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] pipeline = create_pipeline(args.model, dtype) + load_embeddings(pipeline, args.ti_embeddings_dir) + pipeline.unet.load_attn_procs(args.lora_embeddings_dir) + cmd_parser = create_cmd_parser() - cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser) + cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) cmd_prompt.cmdloop() diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 3027421..dab7878 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -1,6 +1,6 @@ import 
inspect import warnings -from typing import List, Optional, Union, Callable +from typing import List, Dict, Any, Optional, Union, Callable import numpy as np import torch @@ -337,6 +337,7 @@ class VlpnStableDiffusion(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -379,6 +380,10 @@ class VlpnStableDiffusion(DiffusionPipeline): return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -450,7 +455,12 @@ class VlpnStableDiffusion(DiffusionPipeline): latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: diff --git a/train_dreambooth.py b/train_dreambooth.py index a70c80e..5a4c47b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -73,26 +73,6 @@ def parse_args(): default=None, help="The name of the current project.", ) - parser.add_argument( - "--placeholder_tokens", - type=str, - nargs='*', - default=[], - help="A token to use as a placeholder for the concept.", - ) - parser.add_argument( - "--initializer_tokens", - type=str, - nargs='*', - default=[], - help="A token to use as initializer word." - ) - parser.add_argument( - "--num_vectors", - type=int, - nargs='*', - help="Number of vectors per embedding." 
- ) parser.add_argument( "--exclude_collections", type=str, @@ -436,30 +416,6 @@ def parse_args(): if args.project is None: raise ValueError("You must specify --project") - if isinstance(args.placeholder_tokens, str): - args.placeholder_tokens = [args.placeholder_tokens] - - if isinstance(args.initializer_tokens, str): - args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) - - if len(args.initializer_tokens) == 0: - raise ValueError("You must specify --initializer_tokens") - - if len(args.placeholder_tokens) == 0: - args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] - - if len(args.placeholder_tokens) != len(args.initializer_tokens): - raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") - - if args.num_vectors is None: - args.num_vectors = 1 - - if isinstance(args.num_vectors, int): - args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) - - if len(args.placeholder_tokens) != len(args.num_vectors): - raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") - if isinstance(args.collection, str): args.collection = [args.collection] @@ -503,7 +459,7 @@ def main(): vae.enable_slicing() vae.set_use_memory_efficient_attention_xformers(True) - unet.set_use_memory_efficient_attention_xformers(True) + unet.enable_xformers_memory_efficient_attention() if args.gradient_checkpointing: unet.enable_gradient_checkpointing() diff --git a/train_lora.py b/train_lora.py new file mode 100644 index 0000000..2cb85cc --- /dev/null +++ b/train_lora.py @@ -0,0 +1,566 @@ +import argparse +import datetime +import logging +import itertools +from pathlib import Path +from functools import partial + +import torch +import torch.utils.checkpoint + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import LoggerType, set_seed +from slugify import slugify +from diffusers.loaders import AttnProcsLayers +from diffusers.models.cross_attention import LoRACrossAttnProcessor + +from util import load_config, load_embeddings_from_dir +from data.csv import VlpnDataModule, keyword_filter +from training.functional import train, get_models +from training.lr import plot_metrics +from training.strategy.lora import lora_strategy +from training.optimization import get_scheduler +from training.util import save_args + +logger = get_logger(__name__) + + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.benchmark = True + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Simple example of a training script." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--train_data_file", + type=str, + default=None, + help="A folder containing the training data." + ) + parser.add_argument( + "--train_data_template", + type=str, + default="template", + ) + parser.add_argument( + "--train_set_pad", + type=int, + default=None, + help="The number to fill train dataset items up to." + ) + parser.add_argument( + "--valid_set_pad", + type=int, + default=None, + help="The number to fill validation dataset items up to." 
+ ) + parser.add_argument( + "--project", + type=str, + default=None, + help="The name of the current project.", + ) + parser.add_argument( + "--exclude_collections", + type=str, + nargs='*', + help="Exclude all items with a listed collection.", + ) + parser.add_argument( + "--num_buckets", + type=int, + default=4, + help="Number of aspect ratio buckets in either direction.", + ) + parser.add_argument( + "--progressive_buckets", + action="store_true", + help="Include images in smaller buckets as well.", + ) + parser.add_argument( + "--bucket_step_size", + type=int, + default=64, + help="Step size between buckets.", + ) + parser.add_argument( + "--bucket_max_pixels", + type=int, + default=None, + help="Maximum pixels per bucket.", + ) + parser.add_argument( + "--tag_dropout", + type=float, + default=0.1, + help="Tag dropout probability.", + ) + parser.add_argument( + "--no_tag_shuffle", + action="store_true", + help="Shuffle tags.", + ) + parser.add_argument( + "--num_class_images", + type=int, + default=0, + help="How many class images to generate." + ) + parser.add_argument( + "--class_image_dir", + type=str, + default="cls", + help="The directory where class images will be saved.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output/lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--embeddings_dir", + type=str, + default=None, + help="The embeddings directory where Textual Inversion embeddings are stored.", + ) + parser.add_argument( + "--collection", + type=str, + nargs='*', + help="A collection to filter the dataset.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="A seed for reproducible training." + ) + parser.add_argument( + "--resolution", + type=int, + default=768, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=100 + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--find_lr", + action="store_true", + help="Automatically find a learning rate (no training).", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=2e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="one_cycle", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup", "one_cycle"]' + ), + ) + parser.add_argument( + "--lr_warmup_epochs", + type=int, + default=10, + help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_cycles", + type=int, + default=None, + help="Number of restart cycles in the lr scheduler (if supported)." 
+ ) + parser.add_argument( + "--lr_warmup_func", + type=str, + default="cos", + help='Choose between ["linear", "cos"]' + ) + parser.add_argument( + "--lr_warmup_exp", + type=int, + default=1, + help='If lr_warmup_func is "cos", exponent to modify the function' + ) + parser.add_argument( + "--lr_annealing_func", + type=str, + default="cos", + help='Choose between ["linear", "half_cos", "cos"]' + ) + parser.add_argument( + "--lr_annealing_exp", + type=int, + default=3, + help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' + ) + parser.add_argument( + "--lr_min_lr", + type=float, + default=0.04, + help="Minimum learning rate in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer." + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer." + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=1e-2, + help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer" + ) + parser.add_argument( + "--adam_amsgrad", + type=bool, + default=False, + help="Amsgrad value for the Adam optimizer" + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument( + "--sample_frequency", + type=int, + default=1, + help="How often to save a checkpoint and sample image", + ) + parser.add_argument( + "--sample_image_size", + type=int, + default=768, + help="Size of sample images", + ) + parser.add_argument( + "--sample_batches", + type=int, + default=1, + help="Number of sample batches to generate per checkpoint", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=1, + help="Number of samples to generate per batch", + ) + parser.add_argument( + "--valid_set_size", + type=int, + default=None, + help="Number of images in the validation dataset." + ) + parser.add_argument( + "--valid_set_repeat", + type=int, + default=1, + help="Times the images in the validation dataset are repeated." + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_steps", + type=int, + default=20, + help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss." + ) + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm." + ) + parser.add_argument( + "--noise_timesteps", + type=int, + default=1000, + ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to a JSON configuration file containing arguments for invoking this script." 
+ ) + + args = parser.parse_args() + if args.config is not None: + args = load_config(args.config) + args = parser.parse_args(namespace=argparse.Namespace(**args)) + + if args.train_data_file is None: + raise ValueError("You must specify --train_data_file") + + if args.pretrained_model_name_or_path is None: + raise ValueError("You must specify --pretrained_model_name_or_path") + + if args.project is None: + raise ValueError("You must specify --project") + + if isinstance(args.collection, str): + args.collection = [args.collection] + + if isinstance(args.exclude_collections, str): + args.exclude_collections = [args.exclude_collections] + + if args.output_dir is None: + raise ValueError("You must specify --output_dir") + + return args + + +def main(): + args = parse_args() + + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) + output_dir.mkdir(parents=True, exist_ok=True) + + accelerator = Accelerator( + log_with=LoggerType.TENSORBOARD, + logging_dir=f"{output_dir}", + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision + ) + + logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + + if args.seed is None: + args.seed = torch.random.seed() >> 32 + + set_seed(args.seed) + + save_args(output_dir, args) + + tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( + args.pretrained_model_name_or_path) + + vae.enable_slicing() + vae.set_use_memory_efficient_attention_xformers(True) + unet.enable_xformers_memory_efficient_attention() + + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRACrossAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + + unet.set_attn_processor(lora_attn_procs) + lora_layers = AttnProcsLayers(unet.attn_processors) + + if args.embeddings_dir is not None: + embeddings_dir = Path(args.embeddings_dir) + if not embeddings_dir.exists() or not embeddings_dir.is_dir(): + raise ValueError("--embeddings_dir must point to an existing directory") + + embeddings.persist() + + added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) + print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * + args.train_batch_size * accelerator.num_processes + ) + + if args.find_lr: + args.learning_rate = 1e-6 + args.lr_scheduler = "exponential_growth" + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + 
weight_dtype = torch.bfloat16 + + trainer = partial( + train, + accelerator=accelerator, + unet=unet, + text_encoder=text_encoder, + vae=vae, + lora_layers=lora_layers, + noise_scheduler=noise_scheduler, + dtype=weight_dtype, + with_prior_preservation=args.num_class_images != 0, + prior_loss_weight=args.prior_loss_weight, + ) + + checkpoint_output_dir = output_dir.joinpath("model") + sample_output_dir = output_dir.joinpath(f"samples") + + datamodule = VlpnDataModule( + data_file=args.train_data_file, + batch_size=args.train_batch_size, + tokenizer=tokenizer, + class_subdir=args.class_image_dir, + num_class_images=args.num_class_images, + size=args.resolution, + num_buckets=args.num_buckets, + progressive_buckets=args.progressive_buckets, + bucket_step_size=args.bucket_step_size, + bucket_max_pixels=args.bucket_max_pixels, + dropout=args.tag_dropout, + shuffle=not args.no_tag_shuffle, + template_key=args.train_data_template, + valid_set_size=args.valid_set_size, + train_set_pad=args.train_set_pad, + valid_set_pad=args.valid_set_pad, + seed=args.seed, + filter=partial(keyword_filter, None, args.collection, args.exclude_collections), + dtype=weight_dtype + ) + datamodule.setup() + + optimizer = optimizer_class( + lora_layers.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + amsgrad=args.adam_amsgrad, + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps_per_epoch=len(datamodule.train_dataloader), + gradient_accumulation_steps=args.gradient_accumulation_steps, + min_lr=args.lr_min_lr, + warmup_func=args.lr_warmup_func, + annealing_func=args.lr_annealing_func, + warmup_exp=args.lr_warmup_exp, + annealing_exp=args.lr_annealing_exp, + cycles=args.lr_cycles, + end_lr=1e2, + train_epochs=args.num_train_epochs, + warmup_epochs=args.lr_warmup_epochs, + ) + + metrics = trainer( + strategy=lora_strategy, + project="lora", + train_dataloader=datamodule.train_dataloader, + val_dataloader=datamodule.val_dataloader, + seed=args.seed, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + num_train_epochs=args.num_train_epochs, + sample_frequency=args.sample_frequency, + # -- + tokenizer=tokenizer, + sample_scheduler=sample_scheduler, + sample_output_dir=sample_output_dir, + checkpoint_output_dir=checkpoint_output_dir, + max_grad_norm=args.max_grad_norm, + sample_batch_size=args.sample_batch_size, + sample_num_batches=args.sample_batches, + sample_num_steps=args.sample_steps, + sample_image_size=args.sample_image_size, + ) + + plot_metrics(metrics, output_dir.joinpath("lr.png")) + + +if __name__ == "__main__": + main() diff --git a/train_ti.py b/train_ti.py index c118aab..56f9e97 100644 --- a/train_ti.py +++ b/train_ti.py @@ -166,7 +166,7 @@ def parse_args(): parser.add_argument( "--tag_dropout", type=float, - default=0, + default=0.1, help="Tag dropout probability.", ) parser.add_argument( @@ -414,7 +414,7 @@ def parse_args(): ) parser.add_argument( "--emb_decay", - default=1e0, + default=1e-2, type=float, help="Embedding decay factor." 
) @@ -530,7 +530,7 @@ def main(): vae.enable_slicing() vae.set_use_memory_efficient_attention_xformers(True) - unet.set_use_memory_efficient_attention_xformers(True) + unet.enable_xformers_memory_efficient_attention() if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -612,8 +612,10 @@ def main(): if len(placeholder_tokens) == 1: sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") + metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png") else: sample_output_dir = output_dir.joinpath("samples") + metrics_output_file = output_dir.joinpath(f"lr.png") placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, @@ -687,7 +689,7 @@ def main(): placeholder_token_ids=placeholder_token_ids, ) - plot_metrics(metrics, output_dir.joinpath("lr.png")) + plot_metrics(metrics, metrics_output_file) if args.simultaneous: run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) diff --git a/training/functional.py b/training/functional.py index c373ac9..8f47734 100644 --- a/training/functional.py +++ b/training/functional.py @@ -34,7 +34,7 @@ def const(result=None): @dataclass class TrainingCallbacks(): on_prepare: Callable[[], None] = const() - on_model: Callable[[], torch.nn.Module] = const(None) + on_accum_model: Callable[[], torch.nn.Module] = const(None) on_log: Callable[[], dict[str, Any]] = const({}) on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) on_before_optimize: Callable[[float, int], None] = const() @@ -51,7 +51,11 @@ class TrainingStrategyPrepareCallable(Protocol): accelerator: Accelerator, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, - *args + optimizer: torch.optim.Optimizer, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + **kwargs ) -> Tuple: ... 
@@ -92,7 +96,6 @@ def save_samples( sample_scheduler: DPMSolverMultistepScheduler, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], - dtype: torch.dtype, output_dir: Path, seed: int, step: int, @@ -107,15 +110,6 @@ def save_samples( grid_cols = min(batch_size, 4) grid_rows = (num_batches * batch_size) // grid_cols - unet = accelerator.unwrap_model(unet) - text_encoder = accelerator.unwrap_model(text_encoder) - - orig_unet_dtype = unet.dtype - orig_text_encoder_dtype = text_encoder.dtype - - unet.to(dtype=dtype) - text_encoder.to(dtype=dtype) - pipeline = VlpnStableDiffusion( text_encoder=text_encoder, vae=vae, @@ -172,11 +166,6 @@ def save_samples( image_grid = make_grid(all_samples, grid_rows, grid_cols) image_grid.save(file_path, quality=85) - unet.to(dtype=orig_unet_dtype) - text_encoder.to(dtype=orig_text_encoder_dtype) - - del unet - del text_encoder del generator del pipeline @@ -393,7 +382,7 @@ def train_loop( ) global_progress_bar.set_description("Total progress") - model = callbacks.on_model() + model = callbacks.on_accum_model() on_log = callbacks.on_log on_train = callbacks.on_train on_before_optimize = callbacks.on_before_optimize @@ -559,8 +548,10 @@ def train( prior_loss_weight: float = 1.0, **kwargs, ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( - accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( + accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) + + kwargs.update(extra) vae.to(accelerator.device, dtype=dtype) diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e88bf90..b4c77f3 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -61,14 +61,11 @@ def dreambooth_strategy_callbacks( save_samples_ = partial( save_samples, accelerator=accelerator, - unet=unet, - text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, sample_scheduler=sample_scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, - dtype=weight_dtype, output_dir=sample_output_dir, seed=seed, batch_size=sample_batch_size, @@ -94,7 +91,7 @@ def dreambooth_strategy_callbacks( else: return nullcontext() - def on_model(): + def on_accum_model(): return unet def on_prepare(): @@ -172,11 +169,29 @@ def dreambooth_strategy_callbacks( @torch.no_grad() def on_sample(step): with ema_context(): - save_samples_(step=step) + unet_ = accelerator.unwrap_model(unet) + text_encoder_ = accelerator.unwrap_model(text_encoder) + + orig_unet_dtype = unet_.dtype + orig_text_encoder_dtype = text_encoder_.dtype + + unet_.to(dtype=weight_dtype) + text_encoder_.to(dtype=weight_dtype) + + save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + + unet_.to(dtype=orig_unet_dtype) + text_encoder_.to(dtype=orig_text_encoder_dtype) + + del unet_ + del text_encoder_ + + if torch.cuda.is_available(): + torch.cuda.empty_cache() return TrainingCallbacks( on_prepare=on_prepare, - on_model=on_model, + on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, @@ -191,9 +206,13 @@ def dreambooth_prepare( accelerator: Accelerator, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, - *args + optimizer: torch.optim.Optimizer, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + lr_scheduler: 
torch.optim.lr_scheduler._LRScheduler, + **kwargs ): - return accelerator.prepare(text_encoder, unet, *args) + return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({}) dreambooth_strategy = TrainingStrategy( diff --git a/training/strategy/lora.py b/training/strategy/lora.py new file mode 100644 index 0000000..88d1824 --- /dev/null +++ b/training/strategy/lora.py @@ -0,0 +1,147 @@ +from contextlib import nullcontext +from typing import Optional +from functools import partial +from contextlib import contextmanager, nullcontext +from pathlib import Path + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from accelerate import Accelerator +from transformers import CLIPTextModel +from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler +from diffusers.loaders import AttnProcsLayers + +from slugify import slugify + +from models.clip.tokenizer import MultiCLIPTokenizer +from training.util import EMAModel +from training.functional import TrainingStrategy, TrainingCallbacks, save_samples + + +def lora_strategy_callbacks( + accelerator: Accelerator, + unet: UNet2DConditionModel, + text_encoder: CLIPTextModel, + tokenizer: MultiCLIPTokenizer, + vae: AutoencoderKL, + sample_scheduler: DPMSolverMultistepScheduler, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + sample_output_dir: Path, + checkpoint_output_dir: Path, + seed: int, + lora_layers: AttnProcsLayers, + max_grad_norm: float = 1.0, + sample_batch_size: int = 1, + sample_num_batches: int = 1, + sample_num_steps: int = 20, + sample_guidance_scale: float = 7.5, + sample_image_size: Optional[int] = None, +): + sample_output_dir.mkdir(parents=True, exist_ok=True) + checkpoint_output_dir.mkdir(parents=True, exist_ok=True) + + weight_dtype = torch.float32 + if accelerator.state.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.state.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_samples_ = partial( + save_samples, + accelerator=accelerator, + unet=unet, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + sample_scheduler=sample_scheduler, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + output_dir=sample_output_dir, + seed=seed, + batch_size=sample_batch_size, + num_batches=sample_num_batches, + num_steps=sample_num_steps, + guidance_scale=sample_guidance_scale, + image_size=sample_image_size, + ) + + def on_prepare(): + lora_layers.requires_grad_(True) + + def on_accum_model(): + return unet + + @contextmanager + def on_train(epoch: int): + tokenizer.train() + yield + + @contextmanager + def on_eval(): + tokenizer.eval() + yield + + def on_before_optimize(lr: float, epoch: int): + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm) + + @torch.no_grad() + def on_checkpoint(step, postfix): + print(f"Saving checkpoint for step {step}...") + orig_unet_dtype = unet.dtype + unet.to(dtype=torch.float32) + unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}")) + unet.to(dtype=orig_unet_dtype) + + @torch.no_grad() + def on_sample(step): + orig_unet_dtype = unet.dtype + unet.to(dtype=weight_dtype) + save_samples_(step=step) + unet.to(dtype=orig_unet_dtype) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return TrainingCallbacks( + on_prepare=on_prepare, + on_accum_model=on_accum_model, + on_train=on_train, + on_eval=on_eval, + 
on_before_optimize=on_before_optimize, + on_checkpoint=on_checkpoint, + on_sample=on_sample, + ) + + +def lora_prepare( + accelerator: Accelerator, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + optimizer: torch.optim.Optimizer, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + lora_layers: AttnProcsLayers, + **kwargs +): + weight_dtype = torch.float32 + if accelerator.state.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.state.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers} + + +lora_strategy = TrainingStrategy( + callbacks=lora_strategy_callbacks, + prepare=lora_prepare, +) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 14bdafd..d306f18 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -59,14 +59,11 @@ def textual_inversion_strategy_callbacks( save_samples_ = partial( save_samples, accelerator=accelerator, - unet=unet, - text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, sample_scheduler=sample_scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, - dtype=weight_dtype, output_dir=sample_output_dir, seed=seed, batch_size=sample_batch_size, @@ -94,7 +91,7 @@ def textual_inversion_strategy_callbacks( else: return nullcontext() - def on_model(): + def on_accum_model(): return text_encoder.text_model.embeddings.temp_token_embedding def on_prepare(): @@ -149,11 +146,29 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_sample(step): with ema_context(): - save_samples_(step=step) + unet_ = accelerator.unwrap_model(unet) + text_encoder_ = accelerator.unwrap_model(text_encoder) + + orig_unet_dtype = unet_.dtype + orig_text_encoder_dtype = text_encoder_.dtype + + unet_.to(dtype=weight_dtype) + text_encoder_.to(dtype=weight_dtype) + + save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + + unet_.to(dtype=orig_unet_dtype) + text_encoder_.to(dtype=orig_text_encoder_dtype) + + del unet_ + del text_encoder_ + + if torch.cuda.is_available(): + torch.cuda.empty_cache() return TrainingCallbacks( on_prepare=on_prepare, - on_model=on_model, + on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, @@ -168,7 +183,11 @@ def textual_inversion_prepare( accelerator: Accelerator, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, - *args + optimizer: torch.optim.Optimizer, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + **kwargs ): weight_dtype = torch.float32 if accelerator.state.mixed_precision == "fp16": @@ -176,9 +195,10 @@ def textual_inversion_prepare( elif accelerator.state.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - prepped = accelerator.prepare(text_encoder, *args) + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) unet.to(accelerator.device, dtype=weight_dtype) - return (prepped[0], unet) + prepped[1:] + return 
text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} textual_inversion_strategy = TrainingStrategy( -- cgit v1.2.3-70-g09d2
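
For readers unfamiliar with LoRA: the LoRACrossAttnProcessor instances installed in train_lora.py keep the original UNet weights frozen and learn only a small low-rank delta per attention projection, which is why only lora_layers.parameters() are handed to the optimizer. A conceptual sketch of that update for a single linear projection (illustrative only, not the actual diffusers implementation):

    import torch
    import torch.nn as nn

    class LoRALinear(nn.Module):
        """Frozen base projection plus a trainable rank-r delta: y = W x + scale * B(A(x))."""
        def __init__(self, base: nn.Linear, rank: int = 4):
            super().__init__()
            self.base = base
            self.base.requires_grad_(False)  # original weight stays frozen
            self.down = nn.Linear(base.in_features, rank, bias=False)  # A: d_in -> r
            self.up = nn.Linear(rank, base.out_features, bias=False)   # B: r -> d_out
            nn.init.normal_(self.down.weight, std=1 / rank)
            nn.init.zeros_(self.up.weight)  # delta starts at zero, so step 0 reproduces the base model

        def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
            return self.base(x) + scale * self.up(self.down(x))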
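
Once training has produced a checkpoint, the weights saved by on_checkpoint() in training/strategy/lora.py are plain attention-processor weights and can be loaded back into any diffusers UNet with load_attn_procs(), which is what the infer.py changes above do via --lora_embeddings_dir and --lora_scale. A minimal sketch of that inference flow, using the stock StableDiffusionPipeline rather than this repo's VlpnStableDiffusion; the model id, checkpoint path, and prompt below are placeholders:

    import torch
    from diffusers import StableDiffusionPipeline

    # Placeholder path: the real checkpoint directory depends on --output_dir,
    # the project slug, and the step/postfix used by on_checkpoint().
    model_id = "stabilityai/stable-diffusion-2-1"
    lora_dir = "output/lora/example-project/model/1000_end"

    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    # Swap the UNet's attention processors for the saved LoRA layers.
    pipe.unet.load_attn_procs(lora_dir)

    # cross_attention_kwargs={"scale": ...} controls how strongly the LoRA deltas are applied,
    # mirroring the new --lora_scale option in infer.py.
    image = pipe(
        "a photo of <concept>",
        num_inference_steps=30,
        guidance_scale=7.0,
        cross_attention_kwargs={"scale": 0.5},
    ).images[0]
    image.save("lora_sample.png")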
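
The training/functional.py change also tightens the strategy interface: prepare() now receives the optimizer, dataloaders, and lr_scheduler explicitly (plus any extra trainer kwargs) and returns them together with a dict that train() merges back into its kwargs — this is how lora_prepare hands the accelerator-wrapped lora_layers through to lora_strategy_callbacks. A sketch of the shape a new strategy's prepare() has to follow under this contract (the device handling for the frozen unet is illustrative):

    import torch
    from typing import Optional
    from torch.utils.data import DataLoader
    from accelerate import Accelerator

    def my_prepare(
        accelerator: Accelerator,
        text_encoder, unet,
        optimizer: torch.optim.Optimizer,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader],
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        **kwargs,
    ):
        # Wrap only what this strategy actually trains; frozen modules are just moved to the device.
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
            text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
        unet.to(accelerator.device)
        # The trailing dict is merged into the trainer's kwargs and forwarded to the strategy callbacks.
        return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}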