From ab24e5cbd8283ad4ced486e1369484ebf9e3962d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 6 Apr 2023 16:06:04 +0200
Subject: Update

---
 infer.py               | 39 +++++++++++++++++++++++++++++++++++++--
 train_lora.py          | 51 ++++++++++++++++++++++++++-------------------
 training/functional.py | 40 ++++------------------------------------
 3 files changed, 73 insertions(+), 57 deletions(-)

diff --git a/infer.py b/infer.py
index ed86ab1..93848d7 100644
--- a/infer.py
+++ b/infer.py
@@ -26,6 +26,8 @@ from diffusers import (
     DEISMultistepScheduler,
     UniPCMultistepScheduler
 )
+from peft import LoraConfig, LoraModel, set_peft_model_state_dict
+from safetensors.torch import load_file
 from transformers import CLIPTextModel
 
 from data.keywords import str_to_keywords, keywords_to_str
@@ -43,7 +45,7 @@ default_args = {
     "model": "stabilityai/stable-diffusion-2-1",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
-    "lora_embeddings_dir": "embeddings_lora",
+    "lora_embedding": None,
     "output_dir": "output/inference",
     "config": None,
 }
@@ -99,7 +101,7 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--lora_embeddings_dir",
+        "--lora_embedding",
         type=str,
     )
     parser.add_argument(
@@ -236,6 +238,38 @@ def load_embeddings(pipeline, embeddings_dir):
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
 
+def load_lora(pipeline, path):
+    if path is None:
+        return
+
+    path = Path(path)
+
+    with open(path / "lora_config.json", "r") as f:
+        lora_config = json.load(f)
+
+    tensor_files = list(path.glob("*_end.safetensors"))
+
+    if len(tensor_files) == 0:
+        return
+
+    lora_checkpoint_sd = load_file(path / tensor_files[0])
+    unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
+    text_encoder_lora_ds = {
+        k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
+    }
+
+    unet_config = LoraConfig(**lora_config["peft_config"])
+    pipeline.unet = LoraModel(unet_config, pipeline.unet)
+    set_peft_model_state_dict(pipeline.unet, unet_lora_ds)
+
+    if "text_encoder_peft_config" in lora_config:
+        text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
+        pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
+        set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
+
+    return
+
+
 def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None):
     if scheduler == "plms":
         return PNDMScheduler.from_config(config)
@@ -441,6 +475,7 @@ def main():
     pipeline = create_pipeline(args.model, dtype)
 
     load_embeddings(pipeline, args.ti_embeddings_dir)
+    load_lora(pipeline, args.lora_embedding)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
     cmd_parser = create_cmd_parser()
diff --git a/train_lora.py b/train_lora.py
index 73b3e19..1ca56d9 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -247,9 +246,15 @@ def parse_args():
         help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate",
+        "--learning_rate_unet",
         type=float,
-        default=2e-6,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--learning_rate_text",
+        type=float,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -548,13 +553,18 @@ def main():
print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * + args.learning_rate_unet = ( + args.learning_rate_unet * args.gradient_accumulation_steps * + args.train_batch_size * accelerator.num_processes + ) + args.learning_rate_text = ( + args.learning_rate_text * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) if args.find_lr: - args.learning_rate = 1e-6 + args.learning_rate_unet = 1e-6 + args.learning_rate_text = 1e-6 args.lr_scheduler = "exponential_growth" if args.optimizer == 'adam8bit': @@ -611,8 +621,8 @@ def main(): ) args.lr_scheduler = "adafactor" - args.lr_min_lr = args.learning_rate - args.learning_rate = None + args.lr_min_lr = args.learning_rate_unet + args.learning_rate_unet = None elif args.optimizer == 'dadam': try: import dadaptation @@ -628,7 +638,8 @@ def main(): d0=args.dadaptation_d0, ) - args.learning_rate = 1.0 + args.learning_rate_unet = 1.0 + args.learning_rate_text = 1.0 elif args.optimizer == 'dadan': try: import dadaptation @@ -642,7 +653,8 @@ def main(): d0=args.dadaptation_d0, ) - args.learning_rate = 1.0 + args.learning_rate_unet = 1.0 + args.learning_rate_text = 1.0 else: raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") @@ -695,15 +707,16 @@ def main(): sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) optimizer = create_optimizer( - ( - param - for param in itertools.chain( - unet.parameters(), - text_encoder.parameters(), - ) - if param.requires_grad - ), - lr=args.learning_rate, + [ + { + "params": unet.parameters(), + "lr": args.learning_rate_unet, + }, + { + "params": text_encoder.parameters(), + "lr": args.learning_rate_text, + }, + ] ) lr_scheduler = get_scheduler( diff --git a/training/functional.py b/training/functional.py index 06848cb..c30d1c0 100644 --- a/training/functional.py +++ b/training/functional.py @@ -321,45 +321,13 @@ def loss_step( ) if offset_noise_strength != 0: - solid_image = partial( - make_solid_image, - shape=images.shape[1:], - vae=vae, + offset_noise = torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), dtype=latents.dtype, device=latents.device, generator=generator - ) - - white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}" - black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}" - - if white_cache_key not in cache: - img_white = solid_image(1) - cache[white_cache_key] = img_white - else: - img_white = cache[white_cache_key] - - if black_cache_key not in cache: - img_black = solid_image(0) - cache[black_cache_key] = img_black - else: - img_black = cache[black_cache_key] - - offset_strength = torch.rand( - (bsz, 1, 1, 1), - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator - ) - offset_strength = offset_noise_strength * (offset_strength * 2 - 1) - offset_images = torch.where( - offset_strength >= 0, - img_white.expand(noise.shape), - img_black.expand(noise.shape) - ) - offset_strength = offset_strength.abs().expand(noise.shape) - noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2)) + ).expand(noise.shape) + noise += offset_noise_strength * offset_noise # Sample a random timestep for each image timesteps = torch.randint( -- cgit v1.2.3-70-g09d2