From d952d467d31786f4a85cc4cb009934cd4ebbba71 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 09:09:46 +0200
Subject: Update

---
 infer.py                                           |   8 +-
 .../stable_diffusion/vlpn_stable_diffusion.py      |   2 +-
 train_lora.py                                      | 303 ++++++++++++++++++---
 train_ti.py                                        |   6 +-
 training/strategy/lora.py                          |  37 ++-
 5 files changed, 305 insertions(+), 51 deletions(-)

diff --git a/infer.py b/infer.py
index 93848d7..8fdf63d 100644
--- a/infer.py
+++ b/infer.py
@@ -67,8 +67,8 @@ default_cmds = {
     "batch_num": 1,
     "steps": 30,
     "guidance_scale": 7.0,
-    "sag_scale": 0.75,
-    "lora_scale": 0.5,
+    "sag_scale": 0,
+    "brightness_offset": 0,
     "seed": None,
     "config": None,
 }
@@ -192,7 +192,7 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
-        "--lora_scale",
+        "--brightness_offset",
         type=float,
     )
     parser.add_argument(
@@ -392,7 +392,7 @@ def generate(output_dir: Path, pipeline, args):
             generator=generator,
             image=init_image,
             strength=args.image_noise,
-            # cross_attention_kwargs={"scale": args.lora_scale},
+            brightness_offset=args.brightness_offset,
         ).images

         for j, image in enumerate(images):
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 127ca50..cfc3208 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -403,7 +403,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
-        sag_scale: float = 0.75,
+        sag_scale: float = 0.0,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
diff --git a/train_lora.py b/train_lora.py
index 1ca56d9..39bf455 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -17,9 +18,10 @@ import transformers

 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from training.functional import train, get_models
+from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
+from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -80,6 +82,43 @@ def parse_args():
         default=None,
         help="The name of the current project.",
     )
+    parser.add_argument(
+        "--placeholder_tokens",
+        type=str,
+        nargs='*',
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--initializer_tokens",
+        type=str,
+        nargs='*',
+        help="A token to use as initializer word."
+    )
+    parser.add_argument(
+        "--initializer_noise",
+        type=float,
+        default=0,
+        help="Noise to apply to the initializer word"
+    )
+    parser.add_argument(
+        "--alias_tokens",
+        type=str,
+        nargs='*',
+        default=[],
+        help="Tokens to create an alias for."
+    )
+    parser.add_argument(
+        "--inverted_initializer_tokens",
+        type=str,
+        nargs='*',
+        help="A token to use as initializer word."
+    )
+    parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
     parser.add_argument(
         "--exclude_collections",
         type=str,
@@ -186,6 +225,16 @@
         type=int,
         default=2000
     )
+    parser.add_argument(
+        "--num_pti_epochs",
+        type=int,
+        default=None
+    )
+    parser.add_argument(
+        "--num_pti_steps",
+        type=int,
+        default=500
+    )
     parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
@@ -257,6 +306,12 @@ def parse_args():
         default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--learning_rate_pti",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -432,6 +487,23 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e+2,
+        type=float,
+        help="Embedding decay factor."
+    )
     parser.add_argument(
         "--max_grad_norm",
         default=1.0,
@@ -464,6 +536,40 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

+    if isinstance(args.placeholder_tokens, str):
+        args.placeholder_tokens = [args.placeholder_tokens]
+
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
+
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")
+
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
+
+    if isinstance(args.inverted_initializer_tokens, str):
+        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens)
+
+    if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0:
+        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
+        args.initializer_tokens += args.inverted_initializer_tokens
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
+
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
+        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
+
+    if args.alias_tokens is None:
+        args.alias_tokens = []
+
+    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
+        raise ValueError("--alias_tokens must be a list with an even number of items")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -544,6 +650,19 @@ def main():
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

+    if len(args.alias_tokens) != 0:
+        alias_placeholder_tokens = args.alias_tokens[::2]
+        alias_initializer_tokens = args.alias_tokens[1::2]
+
+        added_tokens, added_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=alias_placeholder_tokens,
+            initializer_tokens=alias_initializer_tokens
+        )
+        embeddings.persist()
+        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
@@ -552,6 +671,19 @@
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
+        tokenizer=tokenizer,
+        embeddings=embeddings,
+        placeholder_tokens=args.placeholder_tokens,
+        initializer_tokens=args.initializer_tokens,
+        num_vectors=args.num_vectors,
+        initializer_noise=args.initializer_noise,
+    )
+    stats = list(zip(
+        args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids
+    ))
+    print(f"Training embeddings: {stats}")
+
     if args.scale_lr:
         args.learning_rate_unet = (
             args.learning_rate_unet * args.gradient_accumulation_steps *
@@ -561,10 +693,15 @@
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
+        args.learning_rate_pti = (
+            args.learning_rate_pti * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )

     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
+        args.learning_rate_pti = 1e-6
         args.lr_scheduler = "exponential_growth"

     if args.optimizer == 'adam8bit':
@@ -663,18 +800,25 @@ def main():
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
+        tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
+        seed=args.seed,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        offset_noise_strength=args.offset_noise_strength,
+        sample_scheduler=sample_scheduler,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )

-    checkpoint_output_dir = output_dir / "model"
-    sample_output_dir = output_dir/"samples"
-
-    datamodule = VlpnDataModule(
+    create_datamodule = partial(
+        VlpnDataModule,
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
@@ -693,71 +837,146 @@ def main():
         train_set_pad=args.train_set_pad,
         valid_set_pad=args.valid_set_pad,
         seed=args.seed,
+        dtype=weight_dtype,
+    )
+
+    create_lr_scheduler = partial(
+        get_scheduler,
+        args.lr_scheduler,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        end_lr=1e2,
+        warmup_epochs=args.lr_warmup_epochs,
+        mid_point=args.lr_mid_point,
+    )
+
+    # PTI
+    # --------------------------------------------------------------------------------
+
+    pti_output_dir = output_dir / "pti"
+    pti_checkpoint_output_dir = pti_output_dir / "model"
+    pti_sample_output_dir = pti_output_dir / "samples"
+
+    pti_datamodule = create_datamodule(
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+    )
+    pti_datamodule.setup()
+
+    num_pti_epochs = args.num_pti_epochs
+    pti_sample_frequency = args.sample_frequency
+    if num_pti_epochs is None:
+        num_pti_epochs = math.ceil(
+            args.num_pti_steps / len(pti_datamodule.train_dataset)
+        ) * args.gradient_accumulation_steps
+        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+
+    pti_optimizer = create_optimizer(
+        [
+            {
+                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                "lr": args.learning_rate_pti,
+                "weight_decay": 0,
+            },
+        ]
+    )
+
+    pti_lr_scheduler = create_lr_scheduler(
+        optimizer=pti_optimizer,
+        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+        train_epochs=num_pti_epochs,
+    )
+
+    metrics = trainer(
+        strategy=textual_inversion_strategy,
+        project="ti",
+        train_dataloader=pti_datamodule.train_dataloader,
+        val_dataloader=pti_datamodule.val_dataloader,
+        optimizer=pti_optimizer,
+        lr_scheduler=pti_lr_scheduler,
+        num_train_epochs=num_pti_epochs,
+        # --
+        sample_output_dir=pti_sample_output_dir,
+        checkpoint_output_dir=pti_checkpoint_output_dir,
+        sample_frequency=pti_sample_frequency,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+    )
+
+    plot_metrics(metrics, output_dir/"lr.png")
+
+    # LORA
+    # --------------------------------------------------------------------------------
+
+    lora_output_dir = output_dir / "lora"
+    lora_checkpoint_output_dir = lora_output_dir / "model"
+    lora_sample_output_dir = lora_output_dir / "samples"
+
+    lora_datamodule = create_datamodule(
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
-        dtype=weight_dtype
     )
-    datamodule.setup()
+    lora_datamodule.setup()

     num_train_epochs = args.num_train_epochs
-    sample_frequency = args.sample_frequency
+    lora_sample_frequency = args.sample_frequency
     if num_train_epochs is None:
         num_train_epochs = math.ceil(
-            args.num_train_steps / len(datamodule.train_dataset)
+            args.num_train_steps / len(lora_datamodule.train_dataset)
         ) * args.gradient_accumulation_steps
-        sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
+        lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))

-    optimizer = create_optimizer(
+    lora_optimizer = create_optimizer(
         [
             {
                 "params": unet.parameters(),
                 "lr": args.learning_rate_unet,
             },
             {
-                "params": text_encoder.parameters(),
+                "params": itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
+                "lr": args.learning_rate_text,
+            },
+            {
+                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
                 "lr": args.learning_rate_text,
+                "weight_decay": 0,
             },
         ]
     )

-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        end_lr=1e2,
+    lora_lr_scheduler = create_lr_scheduler(
+        optimizer=lora_optimizer,
+        num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
         train_epochs=num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-        mid_point=args.lr_mid_point,
     )

     metrics = trainer(
         strategy=lora_strategy,
         project="lora",
-        train_dataloader=datamodule.train_dataloader,
-        val_dataloader=datamodule.val_dataloader,
-        seed=args.seed,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
+        train_dataloader=lora_datamodule.train_dataloader,
+        val_dataloader=lora_datamodule.val_dataloader,
+        optimizer=lora_optimizer,
+        lr_scheduler=lora_lr_scheduler,
         num_train_epochs=num_train_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        sample_frequency=sample_frequency,
-        offset_noise_strength=args.offset_noise_strength,
         # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
-        sample_output_dir=sample_output_dir,
-        checkpoint_output_dir=checkpoint_output_dir,
+        sample_output_dir=lora_sample_output_dir,
+        checkpoint_output_dir=lora_checkpoint_output_dir,
+        sample_frequency=lora_sample_frequency,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         max_grad_norm=args.max_grad_norm,
-        sample_batch_size=args.sample_batch_size,
-        sample_num_batches=args.sample_batches,
-        sample_num_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
     )

     plot_metrics(metrics, output_dir/"lr.png")
diff --git a/train_ti.py b/train_ti.py
index fc0d68c..344b412 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -750,6 +750,7 @@ def main():
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
+        tokenizer=tokenizer,
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
@@ -764,8 +765,6 @@ def main():
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -774,6 +773,7 @@ def main():
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
         ema_max_decay=args.ema_max_decay,
+        sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -863,9 +863,9 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
-        sample_frequency=sample_frequency,
         # --
         sample_output_dir=sample_output_dir,
+        sample_frequency=sample_frequency,
         placeholder_tokens=placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
     )
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 209785a..d51a2f3 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -14,6 +14,8 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from peft import get_peft_model_state_dict
 from safetensors.torch import save_file

+from slugify import slugify
+
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
@@ -30,6 +32,11 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -77,6 +84,22 @@ def lora_strategy_callbacks(
             max_grad_norm
         )

+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         if postfix != "end":
@@ -87,6 +110,12 @@ def lora_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder_.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
+
         lora_config = {}
         state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
         lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
@@ -126,6 +155,7 @@ def lora_strategy_callbacks(
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
@@ -141,7 +171,12 @@ def lora_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}


 lora_strategy = TrainingStrategy(
-- 
cgit v1.2.3-70-g09d2