From 47b2fba05ba0f2f5335321d94e0634a7291980c5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Dec 2022 16:48:25 +0100
Subject: Some LoRA fixes (still broken)

---
 train_lora.py    | 159 ++++++++++++-------------------------------------
 training/lora.py |  68 +++++++++++++++++++-----
 2 files changed, 89 insertions(+), 138 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index ffc1d10..34e1008 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -24,8 +23,9 @@ from slugify import slugify
 from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
-from training.lora import LoraAttention
+from training.lora import LoraAttnProcessor
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -109,7 +109,7 @@ def parse_args():
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="output/dreambooth",
+        default="output/lora",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
@@ -176,7 +176,7 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -348,76 +348,45 @@ def parse_args():
     return args
 
 
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.sum = self.count = self.avg = 0
-
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
         accelerator,
         vae,
         unet,
-        unet_lora,
         tokenizer,
         text_encoder,
+        unet_lora,
         scheduler,
-        output_dir: Path,
         instance_identifier,
         placeholder_token,
         placeholder_token_id,
+        output_dir: Path,
         sample_image_size,
         sample_batches,
        sample_batch_size,
         seed
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
-        self.unet_lora = unet_lora
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.unet_lora = unet_lora
         self.scheduler = scheduler
-        self.output_dir = output_dir
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
     def save_model(self):
@@ -433,83 +402,18 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
+        # Save a sample image
         pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
+            text_encoder=self.text_encoder,
             vae=self.vae,
-            unet=unet,
+            unet=self.unet,
             tokenizer=self.tokenizer,
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
-
-        del unet
-        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -558,7 +462,11 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    unet_lora = LoraAttention()
+    unet_lora = LoraAttnProcessor(
+        cross_attention_dim=unet.cross_attention_dim,
+        inner_dim=unet.in_channels,
+        r=4,
+    )
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -618,8 +526,8 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
@@ -639,7 +547,7 @@ def main():
         [
             {
                 'params': unet_lora.parameters(),
-                'lr': args.learning_rate_unet,
+                'lr': args.learning_rate,
             },
         ],
         betas=(args.adam_beta1, args.adam_beta2),
@@ -801,7 +709,7 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
-    accelerator.init_trackers("dreambooth", config=config)
+    accelerator.init_trackers("lora", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -832,6 +740,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
+        unet_lora=unet_lora,
         output_dir=basepath,
         instance_identifier=instance_identifier,
         placeholder_token=args.placeholder_token,
diff --git a/training/lora.py b/training/lora.py
index d8dc147..e1c0971 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,27 +1,69 @@
 import torch.nn as nn
 
-from diffusers import ModelMixin, ConfigMixin, XFormersCrossAttnProcessor, register_to_config
+from diffusers import ModelMixin, ConfigMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.cross_attention import CrossAttention
+from diffusers.utils.import_utils import is_xformers_available
 
-class LoraAttention(ModelMixin, ConfigMixin):
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+class LoraAttnProcessor(ModelMixin, ConfigMixin):
     @register_to_config
-    def __init__(self, in_features, out_features, r=4):
+    def __init__(
+        self,
+        cross_attention_dim,
+        inner_dim,
+        r: int = 4
+    ):
         super().__init__()
 
-        if r > min(in_features, out_features):
+        if r > min(cross_attention_dim, inner_dim):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
             )
 
-        self.lora_down = nn.Linear(in_features, r, bias=False)
-        self.lora_up = nn.Linear(r, out_features, bias=False)
+        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+
+        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+
         self.scale = 1.0
 
-        self.processor = XFormersCrossAttnProcessor()
+        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_k_up.weight)
+
+        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_v_up.weight)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
+        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_up.weight)
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
-        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number)
-        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
         return hidden_states
-- 
cgit v1.2.3-70-g09d2
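
Note on the low-rank update introduced in training/lora.py above: each down/up pair (lora_k_down/lora_k_up and lora_v_down/lora_v_up) adds a rank-r correction on top of the frozen to_k/to_v projections, scaled by self.scale, and the zero-initialized up projection makes that branch a no-op at the start of training. The standalone PyTorch sketch below illustrates the same idea outside the patch; the dimensions, the frozen_proj stand-in layer, and the non-zero up-projection init are made up for the illustration and are not taken from the commit.

import torch
import torch.nn as nn

# Illustrative dimensions; in the patch they come from the UNet
# (cross_attention_dim and in_channels), and r=4 matches train_lora.py.
cross_attention_dim, inner_dim, r = 768, 320, 4

frozen_proj = nn.Linear(cross_attention_dim, inner_dim, bias=False)  # stand-in for attn.to_k / attn.to_v
lora_down = nn.Linear(cross_attention_dim, r, bias=False)
lora_up = nn.Linear(r, inner_dim, bias=False)
scale = 1.0

# The patch initializes the down projection with std=1 / r**2 and zeros the
# up projection (so the LoRA branch starts as a no-op); random up-weights are
# used here only so the rank check below is not trivially zero.
nn.init.normal_(lora_down.weight, std=1 / r**2)
nn.init.normal_(lora_up.weight, std=0.02)

x = torch.randn(2, 77, cross_attention_dim)  # (batch, tokens, cross_attention_dim)
key = frozen_proj(x) + lora_up(lora_down(x)) * scale

# The LoRA branch is equivalent to adding a rank-r matrix to the frozen weight.
delta_w = scale * lora_up.weight @ lora_down.weight  # (inner_dim, cross_attention_dim)
assert torch.linalg.matrix_rank(delta_w) <= r
assert torch.allclose(key, x @ (frozen_proj.weight + delta_w).T, atol=1e-4)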