From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Apr 2023 13:11:11 +0200 Subject: TI via LoRA --- models/clip/embeddings.py | 76 ++++-------- models/lora.py | 131 +++++++++++++++++++++ models/sparse.py | 66 ----------- .../stable_diffusion/vlpn_stable_diffusion.py | 20 +++- train_lora.py | 7 +- train_ti.py | 28 ++++- training/functional.py | 11 +- training/strategy/lora.py | 4 +- training/strategy/ti.py | 9 +- 9 files changed, 212 insertions(+), 140 deletions(-) create mode 100644 models/lora.py delete mode 100644 models/sparse.py diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9be8256..60c1b20 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -11,49 +11,27 @@ from transformers import CLIPTextModel from transformers.models.clip import CLIPTextConfig from transformers.models.clip.modeling_clip import CLIPTextEmbeddings -from models.sparse import PseudoSparseEmbedding - - -def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding: - old_num_embeddings, old_embedding_dim = old_embedding.weight.shape - - if old_num_embeddings == new_num_embeddings: - return old_embedding - - n = min(old_num_embeddings, new_num_embeddings) - - new_embedding = nn.Embedding( - new_num_embeddings, - old_embedding_dim, - device=old_embedding.weight.device, - dtype=old_embedding.weight.dtype - ) - if initializer_factor is not None: - new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) - else: - nn.init.zeros_(new_embedding.weight.data) - new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] - return new_embedding +from models.lora import LoraEmbedding class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): - def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0): + def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0): super().__init__(config) - self.token_embedding = embeddings.token_embedding self.position_embedding = embeddings.position_embedding self.initializer_factor = config.initializer_factor - - self.token_override_embedding = PseudoSparseEmbedding( + self.token_embedding = LoraEmbedding( + self.token_embedding.num_embeddings, self.token_embedding.embedding_dim, - dropout_p=dropout_p, - device=self.token_embedding.weight.device, - dtype=self.token_embedding.weight.dtype, + r, + lora_alpha, + lora_dropout, ) + self.token_embedding.weight = embeddings.token_embedding.weight + def resize(self, size: int): - self.token_override_embedding.resize(size) - self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) + self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor) def add_embed( self, @@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.token_embedding.weight.data[token_ids] = initializer - self.token_override_embedding.set(token_ids, initializer) def load_embed(self, input_ids: list[int], filename: Path): with safe_open(filename, framework="pt", device="cpu") as file: @@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - input_ids = torch.arange( - self.token_embedding.num_embeddings, - device=self.token_override_embedding.mapping.device - ) - 
embs, mask = self.token_override_embedding(input_ids) - if embs is not None: - input_ids = input_ids[mask] - self.token_embedding.weight.data[input_ids] = embs - self.token_override_embedding.unset(input_ids) + self.token_embedding.eval() + self.token_embedding.merged = False def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): if isinstance(input_ids, list): input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) - embs = self.token_embedding(input_ids) - embs_override, mask = self.token_override_embedding(input_ids) - if embs_override is not None: - embs[mask] = embs_override - - return embs + return self.token_embedding(input_ids) def forward( self, @@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return embeddings -def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings: - text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p) +def patch_managed_embeddings( + text_encoder: CLIPTextModel, + r: int = 8, + lora_alpha: int = 8, + lora_dropout: float = 0.0 +) -> ManagedCLIPTextEmbeddings: + text_embeddings = ManagedCLIPTextEmbeddings( + text_encoder.config, + text_encoder.text_model.embeddings, + r, + lora_alpha, + lora_dropout + ) text_encoder.text_model.embeddings = text_embeddings return text_embeddings diff --git a/models/lora.py b/models/lora.py new file mode 100644 index 0000000..c0f74a6 --- /dev/null +++ b/models/lora.py @@ -0,0 +1,131 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LoraLayer(): + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + self.lora_dropout_p = lora_dropout + + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = nn.Identity() + + self.merged = False + self.merge_weights = merge_weights + + +class LoraEmbedding(nn.Embedding, LoraLayer): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + merge_weights: bool = True, + **kwargs + ): + nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) + LoraLayer.__init__( + self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights + ) + + self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long)) + self.trainable_ids -= 1 + + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0))) + self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r))) + self.scaling = self.lora_alpha / self.r + self.weight.requires_grad = False + + self.reset_parameters() + + def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): + n = min(self.num_embeddings, new_num_embeddings) + + new_emb = LoraEmbedding( + new_num_embeddings, + self.embedding_dim, + self.r, + self.lora_alpha, + self.lora_dropout_p, + device=self.weight.device, + dtype=self.weight.dtype + ) + if initializer_factor is not None: + new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) + else: + nn.init.zeros_(new_emb.weight.data) + new_emb.weight.data[:n, :] = self.weight.data[:n, :] + new_emb.lora_A = self.lora_A + new_emb.lora_B = self.lora_B + new_emb.trainable_ids[:n] = self.trainable_ids[:n] + + return new_emb + + def 
mark_trainable(self, input_ids): + trainable_ids = self.trainable_ids[input_ids] + new_ids = trainable_ids[trainable_ids == -1] + + if new_ids.shape[0] == 0: + return + + n = self.trainable_ids.shape[0] + self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0]) + + lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0))) + lora_A.data[:n] = self.lora_A.data + self.lora_A = lora_A + + def reset_parameters(self): + nn.Embedding.reset_parameters(self) + if hasattr(self, 'lora_A'): + nn.init.zeros_(self.lora_A) + nn.init.normal_(self.lora_B) + + def train(self, mode: bool = True): + nn.Embedding.train(self, mode) + if self.merge_weights and self.merged: + if self.r > 0: + mask = ~(self.trainable_ids == -1) + trainable_ids = self.trainable_ids[mask] + self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling + self.merged = False + + def eval(self): + nn.Embedding.eval(self) + if self.merge_weights and not self.merged: + if self.r > 0: + mask = ~(self.trainable_ids == -1) + trainable_ids = self.trainable_ids[mask] + self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling + self.merged = True + + def forward(self, input_ids: torch.Tensor): + result = nn.Embedding.forward(self, input_ids) + + if self.r > 0 and not self.merged: + trainable_ids = self.trainable_ids[input_ids] + mask = ~(trainable_ids == -1) + trainable_ids = trainable_ids[mask] + + after_A = F.embedding( + trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse + ) + result[mask] += (after_A @ self.lora_B.T) * self.scaling + + return result diff --git a/models/sparse.py b/models/sparse.py deleted file mode 100644 index 07b3413..0000000 --- a/models/sparse.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Optional - -import torch -import torch.nn as nn - - -class PseudoSparseEmbedding(nn.Module): - def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32): - super().__init__() - - self.embedding_dim = embedding_dim - self.dtype = dtype - self.params = nn.ParameterList() - - if dropout_p > 0.0: - self.dropout = nn.Dropout(p=dropout_p) - else: - self.dropout = nn.Identity() - - self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) - - def forward(self, input_ids: torch.LongTensor): - input_ids = input_ids.to(self.mapping.device) - ids = self.mapping[input_ids] - mask = ~(ids == -1) - - if torch.all(~mask): - embs = None - else: - embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]])) - - return embs, mask - - def resize(self, new_num_embeddings: int): - old_num_embeddings = self.mapping.shape[0] - n = min(old_num_embeddings, new_num_embeddings) - - new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1 - new_mapping[:n] = self.mapping[:n] - - self.mapping = new_mapping - - def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None): - if len(input_ids.shape) != 0: - if tensor is not None: - return [self.set(id, t) for id, t in zip(input_ids, tensor)] - else: - return [self.set(id) for id in input_ids] - - if tensor is None: - tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype) - - if tensor.shape[-1] != self.embedding_dim: - raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]") - - id = self.mapping[input_ids] - - if id == -1: - id = len(self.params) - 
self.mapping[input_ids] = id - self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) - - self.params[id] = tensor - - def unset(self, input_ids: torch.LongTensor): - self.mapping[input_ids] = -1 diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 13ea2ac..a0dff54 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -591,15 +591,23 @@ class VlpnStableDiffusion(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker has_nsfw_concept = None - # 11. Convert to PIL - if output_type == "pil": + if output_type == "latent": + image = latents + elif output_type == "pil": + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL image = self.numpy_to_pil(image) + else: + # 9. Post-processing + image = self.decode_latents(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) diff --git a/train_lora.py b/train_lora.py index b8c7396..91bda5c 100644 --- a/train_lora.py +++ b/train_lora.py @@ -387,7 +387,7 @@ def parse_args(): parser.add_argument( "--optimizer", type=str, - default="dadan", + default="adan", choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], help='Optimizer to use' ) @@ -412,7 +412,7 @@ def parse_args(): parser.add_argument( "--adam_weight_decay", type=float, - default=1e-2, + default=2e-2, help="Weight decay to use." ) parser.add_argument( @@ -780,6 +780,7 @@ def main(): timm.optim.Adan, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, + no_prox=True, ) elif args.optimizer == 'lion': try: @@ -961,7 +962,7 @@ def main(): if len(args.placeholder_tokens) != 0: params_to_optimize.append({ - "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), + "params": text_encoder.text_model.embeddings.token_embedding.parameters(), "lr": learning_rate_emb, "weight_decay": 0, }) diff --git a/train_ti.py b/train_ti.py index d931db6..6c57f4b 100644 --- a/train_ti.py +++ b/train_ti.py @@ -18,7 +18,6 @@ import transformers from util.files import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter -from models.convnext.discriminator import ConvNeXtDiscriminator from training.functional import train, add_placeholder_tokens, get_models from training.strategy.ti import textual_inversion_strategy from training.optimization import get_scheduler @@ -354,7 +353,7 @@ def parse_args(): parser.add_argument( "--optimizer", type=str, - default="dadan", + default="adan", choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], help='Optimizer to use' ) @@ -379,7 +378,7 @@ def parse_args(): parser.add_argument( "--adam_weight_decay", type=float, - default=0, + default=2e-2, help="Weight decay to use." ) parser.add_argument( @@ -483,7 +482,19 @@ def parse_args(): help="The weight of prior preservation loss." 
) parser.add_argument( - "--emb_dropout", + "--lora_r", + type=int, + default=8, + help="Lora rank, only used if use_lora is True" + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=32, + help="Lora alpha, only used if use_lora is True" + ) + parser.add_argument( + "--lora_dropout", type=float, default=0, help="Embedding dropout probability.", @@ -655,7 +666,11 @@ def main(): save_args(output_dir, args) tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( - args.pretrained_model_name_or_path, args.emb_dropout) + args.pretrained_model_name_or_path, + args.lora_r, + args.lora_alpha, + args.lora_dropout + ) tokenizer.set_use_vector_shuffle(args.vector_shuffle) tokenizer.set_dropout(args.vector_dropout) @@ -747,6 +762,7 @@ def main(): timm.optim.Adan, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, + no_prox=True, ) elif args.optimizer == 'lion': try: @@ -914,7 +930,7 @@ def main(): print("") optimizer = create_optimizer( - text_encoder.text_model.embeddings.token_override_embedding.parameters(), + text_encoder.text_model.embeddings.token_embedding.parameters(), lr=learning_rate, ) diff --git a/training/functional.py b/training/functional.py index 54bbe78..1fdfdc8 100644 --- a/training/functional.py +++ b/training/functional.py @@ -66,7 +66,12 @@ class TrainingStrategy(): prepare: TrainingStrategyPrepareCallable -def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): +def get_models( + pretrained_model_name_or_path: str, + emb_r: int = 8, + emb_lora_alpha: int = 8, + emb_lora_dropout: float = 0.0 +): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') @@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') - embeddings = patch_managed_embeddings(text_encoder, emb_dropout) + embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout) return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings @@ -653,6 +658,8 @@ def train_loop( on_checkpoint(global_step, "end") raise KeyboardInterrupt + return avg_loss, avg_acc, avg_loss_val, avg_acc_val + def train( accelerator: Accelerator, diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1517ee8..48236fb 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -93,7 +93,7 @@ def lora_strategy_callbacks( if use_emb_decay: params = [ p - for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() + for p in text_encoder.text_model.embeddings.parameters() if p.grad is not None ] return torch.stack(params) if len(params) != 0 else None @@ -180,7 +180,7 @@ def lora_prepare( text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) - text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 
ca7cc3d..49236c6 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks( if use_ema: ema_embeddings = EMAModel( - text_encoder.text_model.embeddings.token_override_embedding.parameters(), + text_encoder.text_model.embeddings.token_embedding.parameters(), inv_gamma=ema_inv_gamma, power=ema_power, max_value=ema_max_decay, @@ -84,7 +84,7 @@ def textual_inversion_strategy_callbacks( def ema_context(): if ema_embeddings is not None: return ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.token_override_embedding.parameters() + text_encoder.text_model.embeddings.token_embedding.parameters() ) else: return nullcontext() @@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks( if use_emb_decay: params = [ p - for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() + for p in text_encoder.text_model.embeddings.token_embedding.parameters() if p.grad is not None ] return torch.stack(params) if len(params) != 0 else None @@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_after_optimize(w, lrs: dict[str, float]): if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters()) + ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) if use_emb_decay and w is not None: lr = lrs["emb"] if "emb" in lrs else lrs["0"] @@ -203,7 +203,6 @@ def textual_inversion_prepare( text_encoder.text_model.encoder.requires_grad_(False) text_encoder.text_model.final_layer_norm.requires_grad_(False) text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler -- cgit v1.2.3-70-g09d2
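
Note: the commit replaces the former PseudoSparseEmbedding override with a LoraEmbedding whose lora_A is meant to gain one column per trainable token id, keeping the base token_embedding.weight frozen and adding a low-rank delta scaled by lora_alpha / r. The following standalone sketch illustrates that idea; it is not the repository code, the class and variable names are hypothetical, and it deliberately simplifies merging, dropout, and resizing.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLoraEmbedding(nn.Module):
    """Frozen embedding table plus a low-rank update that only covers marked token ids."""

    def __init__(self, num_embeddings: int, embedding_dim: int, r: int = 8, lora_alpha: int = 8):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim), requires_grad=False)
        self.r = r
        self.scaling = lora_alpha / r
        # One column of lora_A per trainable token; starts empty and grows on demand.
        self.lora_A = nn.Parameter(self.weight.new_zeros(r, 0))
        self.lora_B = nn.Parameter(torch.randn(embedding_dim, r) * 0.02)
        # Maps a token id to its column in lora_A, or -1 if the token stays frozen.
        self.register_buffer("trainable_ids", torch.full((num_embeddings,), -1, dtype=torch.long))

    def mark_trainable(self, token_ids: torch.LongTensor) -> None:
        # Assign a lora_A column to every id that does not have one yet.
        new = token_ids[self.trainable_ids[token_ids] == -1]
        if new.numel() == 0:
            return
        start = self.lora_A.shape[1]
        self.trainable_ids[new] = torch.arange(start, start + new.numel())
        grown = self.weight.new_zeros(self.r, start + new.numel())
        grown[:, :start] = self.lora_A.data
        self.lora_A = nn.Parameter(grown)

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        result = F.embedding(input_ids, self.weight)
        cols = self.trainable_ids[input_ids]
        mask = cols != -1
        if mask.any():
            # delta = (B @ A[:, cols]) * alpha / r, applied only at the marked positions.
            delta = (self.lora_B @ self.lora_A[:, cols[mask]]).transpose(0, 1) * self.scaling
            result[mask] = result[mask] + delta
        return result


# Minimal demonstration: mark two placeholder token ids as trainable and check that
# gradients reach only the LoRA factors, never the frozen base weight.
emb = TinyLoraEmbedding(num_embeddings=16, embedding_dim=32, r=4)
emb.mark_trainable(torch.tensor([3, 7]))
out = emb(torch.tensor([[1, 3, 7, 2]]))
out.sum().backward()
print(out.shape)                               # torch.Size([1, 4, 32])
print(emb.lora_A.grad.shape, emb.weight.grad)  # torch.Size([4, 2]) None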
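
A hedged sketch of how the new entry points fit together on the training side, following the get_models signature and the token_embedding optimizer parameters introduced above; the checkpoint path, learning rate, and the use of torch.optim.AdamW are placeholders standing in for the repository's create_optimizer factory and CLI flags (--lora_r, --lora_alpha, --lora_dropout).

import torch
from training.functional import get_models

# get_models(...) now takes the embedding LoRA hyperparameters and patches the text
# encoder's token embedding via patch_managed_embeddings (training/functional.py above).
tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    "path/to/stable-diffusion",  # placeholder checkpoint path
    8,                           # emb_r (--lora_r)
    32,                          # emb_lora_alpha (--lora_alpha, train_ti.py default)
    0.0,                         # emb_lora_dropout (--lora_dropout)
)

# As in train_ti.py, only the LoRA-augmented token embedding is handed to the optimizer;
# its frozen base weight has requires_grad=False and therefore never receives gradients.
optimizer = torch.optim.AdamW(
    text_encoder.text_model.embeddings.token_embedding.parameters(),
    lr=1e-2,          # placeholder learning rate
    weight_decay=0.0,
)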