From 403f525d0c6900cc6844c0d2f4ecb385fc131969 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 4 Jan 2023 09:40:24 +0100 Subject: Fixed reproducibility, more consistant validation --- data/csv.py | 10 ++++-- train_dreambooth.py | 24 +++++++++++--- train_ti.py | 24 +++++++++++--- training/lora.py | 92 +++++++++++++++++++++++++++++++++++++---------------- training/lr.py | 6 ++-- 5 files changed, 113 insertions(+), 43 deletions(-) diff --git a/data/csv.py b/data/csv.py index af36d9e..e901ab4 100644 --- a/data/csv.py +++ b/data/csv.py @@ -59,7 +59,7 @@ class CSVDataModule(): center_crop: bool = False, template_key: str = "template", valid_set_size: Optional[int] = None, - generator: Optional[torch.Generator] = None, + seed: Optional[int] = None, filter: Optional[Callable[[CSVDataItem], bool]] = None, collate_fn=None, num_workers: int = 0 @@ -84,7 +84,7 @@ class CSVDataModule(): self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size - self.generator = generator + self.seed = seed self.filter = filter self.collate_fn = collate_fn self.num_workers = num_workers @@ -155,7 +155,11 @@ class CSVDataModule(): valid_set_size = max(valid_set_size, 1) train_set_size = num_images - valid_set_size - data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) + generator = torch.Generator(device="cpu") + if self.seed is not None: + generator = generator.manual_seed(self.seed) + + data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) self.data_train = self.pad_items(data_train, self.num_class_images) self.data_val = self.pad_items(data_val) diff --git a/train_dreambooth.py b/train_dreambooth.py index df8b54c..6d9bae8 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -319,6 +319,12 @@ def parse_args(): default=1e-08, help="Epsilon value for the Adam optimizer" ) + parser.add_argument( + "--adam_amsgrad", + type=bool, + default=False, + help="Amsgrad value for the Adam optimizer" + ) parser.add_argument( "--mixed_precision", type=str, @@ -642,7 +648,7 @@ def main(): ) if args.find_lr: - args.learning_rate = 1e-4 + args.learning_rate = 1e-6 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -674,6 +680,7 @@ def main(): betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, + amsgrad=args.adam_amsgrad, ) weight_dtype = torch.float32 @@ -730,6 +737,7 @@ def main(): template_key=args.train_data_template, valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, + seed=args.seed, filter=keyword_filter, collate_fn=collate_fn ) @@ -840,7 +848,7 @@ def main(): def on_eval(): tokenizer.eval() - def loop(batch): + def loop(batch, eval: bool = False): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -849,8 +857,14 @@ def main(): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) + timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + generator=timesteps_gen, + device=latents.device, + ) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep @@ -1051,7 +1065,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loop(batch) + loss, acc, bsz = loop(batch, True) loss = loss.detach_() acc = acc.detach_() diff --git a/train_ti.py b/train_ti.py index 1685dc4..5d6eafc 100644 --- a/train_ti.py +++ b/train_ti.py @@ -288,6 +288,12 @@ def parse_args(): default=1e-08, help="Epsilon value for the Adam optimizer" ) + parser.add_argument( + "--adam_amsgrad", + type=bool, + default=False, + help="Amsgrad value for the Adam optimizer" + ) parser.add_argument( "--mixed_precision", type=str, @@ -592,7 +598,7 @@ def main(): ) if args.find_lr: - args.learning_rate = 1e-4 + args.learning_rate = 1e-6 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -612,6 +618,7 @@ def main(): betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, + amsgrad=args.adam_amsgrad, ) weight_dtype = torch.float32 @@ -673,6 +680,7 @@ def main(): template_key=args.train_data_template, valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, + seed=args.seed, filter=keyword_filter, collate_fn=collate_fn ) @@ -791,7 +799,7 @@ def main(): def on_eval(): tokenizer.eval() - def loop(batch): + def loop(batch, eval: bool = False): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 @@ -800,8 +808,14 @@ def main(): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, - (bsz,), device=latents.device) + timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + generator=timesteps_gen, + device=latents.device, + ) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep @@ -984,7 +998,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loop(batch) + loss, acc, bsz = loop(batch, True) loss = loss.detach_() acc = acc.detach_() diff --git a/training/lora.py b/training/lora.py index e1c0971..3857d78 100644 --- a/training/lora.py +++ b/training/lora.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn from diffusers import ModelMixin, ConfigMixin @@ -13,56 +14,93 @@ else: xformers = None -class LoraAttnProcessor(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - cross_attention_dim, - inner_dim, - r: int = 4 - ): +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4): super().__init__() - if r > min(cross_attention_dim, inner_dim): + if rank > min(in_features, out_features): raise ValueError( - f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}" + f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}" ) - self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False) - self.lora_k_up = nn.Linear(r, inner_dim, bias=False) + self.lora_down = nn.Linear(in_features, rank, bias=False) + self.lora_up = nn.Linear(rank, out_features, bias=False) + self.scale = 1.0 - self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False) - self.lora_v_up = nn.Linear(r, inner_dim, bias=False) + nn.init.normal_(self.lora_down.weight, std=1 / rank) + nn.init.zeros_(self.lora_up.weight) - self.scale = 1.0 + def forward(self, hidden_states): + down_hidden_states = self.lora_down(hidden_states) + up_hidden_states = self.lora_up(down_hidden_states) - nn.init.normal_(self.lora_k_down.weight, std=1 / r**2) - nn.init.zeros_(self.lora_k_up.weight) + return up_hidden_states - nn.init.normal_(self.lora_v_down.weight, std=1 / r**2) - nn.init.zeros_(self.lora_v_up.weight) - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape +class LoRACrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + super().__init__() + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) - query = attn.to_q(hidden_states) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale - value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class LoRAXFormersCrossAttnProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim, rank=4): + super().__init__() + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size) + + def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query).contiguous() + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) - hidden_states = hidden_states.to(query.dtype) - hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/training/lr.py b/training/lr.py index 37588b6..a3144ba 100644 --- a/training/lr.py +++ b/training/lr.py @@ -1,6 +1,6 @@ import math import copy -from typing import Callable +from typing import Callable, Any, Tuple, Union from functools import partial import matplotlib.pyplot as plt @@ -24,7 +24,7 @@ class LRFinder(): optimizer, train_dataloader, val_dataloader, - loss_fn, + loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]], on_train: Callable[[], None] = noop, on_eval: Callable[[], None] = noop ): @@ -108,7 +108,7 @@ class LRFinder(): if step >= num_val_batches: break - loss, acc, bsz = self.loss_fn(batch) + loss, acc, bsz = self.loss_fn(batch, True) avg_loss.update(loss.detach_(), bsz) avg_acc.update(acc.detach_(), bsz) -- cgit v1.2.3-54-g00ecf