From 208e48134e324e934ad964bdc61880cc923f4c0d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 22:13:55 +0200
Subject: Revert

---
 models/clip/embeddings.py |  42 ++-----------------
 train_ti.py               |  52 +++++++++++++++++++++---
 training/functional.py    |   2 +-
 training/strategy/ti.py   | 100 +++++++++++++++++++++++++++++++++++++---------
 4 files changed, 132 insertions(+), 64 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index c9c788c..1e21965 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,41 +31,15 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding


-class OverlayLinear(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
-
-        self.rank = rank
-        self.down = nn.Linear(in_features, rank, bias=False)
-        self.up = nn.Linear(rank, out_features, bias=False)
-        self.reset()
-
-    def reset(self):
-        nn.init.normal_(self.down.weight, std=1 / self.rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states)
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4):
         super().__init__()

         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
+        self.alpha = alpha

-        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -75,9 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)

-    def reset_overlay(self):
-        self.overlay.reset()
-
     def resize(self, size: int):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
@@ -125,9 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)

     def persist(self):
-        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
-        self.overlay.reset()
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)

     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -135,11 +104,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)

         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
-        temp_embeds = self.temp_token_embedding(input_ids[mask])
-        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
+        embeds[mask] = self.temp_token_embedding(input_ids[mask])

         return embeds

diff --git a/train_ti.py b/train_ti.py
index 26ac384..5482326 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -307,6 +306,26 @@ def parse_args():
         default=0.04,
         help="Minimum learning rate in the lr scheduler."
     )
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=4/5
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -334,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -431,6 +450,23 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e2,
+        type=float,
+        help="Embedding decay factor."
+    )
     parser.add_argument(
         "--noise_timesteps",
         type=int,
@@ -696,6 +732,13 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -757,10 +800,7 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))

     optimizer = create_optimizer(
-        itertools.chain(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            text_encoder.text_model.embeddings.overlay.parameters(),
-        ),
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
     )

diff --git a/training/functional.py b/training/functional.py
index 7104a88..bd8cbad 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -524,7 +524,7 @@ def train_loop(

                 lr = lr_scheduler.get_last_lr()[0]
                 if torch.is_tensor(lr):
-                    lr = lr.item
+                    lr = lr.item()

                 lrs.append(lr)

diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 1b5adab..677f5a3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,6 +1,6 @@
 from typing import Optional
 from functools import partial
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from pathlib import Path

 import torch
@@ -13,6 +13,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from slugify import slugify

 from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


@@ -31,6 +32,13 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: int = 1,
+    ema_max_decay: float = 0.9999,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
     sample_num_steps: int = 20,
@@ -63,8 +71,27 @@ def textual_inversion_strategy_callbacks(
         image_size=sample_image_size,
     )

+    if use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+        ema_embeddings.to(accelerator.device)
+    else:
+        ema_embeddings = None
+
+    def ema_context():
+        if ema_embeddings is not None:
+            return ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            )
+        else:
+            return nullcontext()
+
     def on_accum_model():
-        return text_encoder.text_model.embeddings
+        return text_encoder.text_model.embeddings.temp_token_embedding

     @contextmanager
     def on_train(epoch: int):
@@ -74,36 +101,68 @@
     @contextmanager
     def on_eval():
         tokenizer.eval()
-        yield
+
+        with ema_context():
+            yield
+
+    @torch.no_grad()
+    def on_before_optimize(lr: float, epoch: int):
+        if use_emb_decay:
+            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+            return torch.all(w.grad == 0, dim=1)
+
+    @torch.no_grad()
+    def on_after_optimize(zero_ids, lr: float):
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.ones(w.shape[0], dtype=torch.bool)
+                mask[zero_ids] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
+    def on_log():
+        if ema_embeddings is not None:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}

     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")

-        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
-            text_encoder.text_model.embeddings.save_embed(
-                ids,
-                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
-            )
+        with ema_context():
+            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                )

     @torch.no_grad()
     def on_sample(step):
-        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+        with ema_context():
+            unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)

-        orig_unet_dtype = unet_.dtype
-        orig_text_encoder_dtype = text_encoder_.dtype
+            orig_unet_dtype = unet_.dtype
+            orig_text_encoder_dtype = text_encoder_.dtype

-        unet_.to(dtype=weight_dtype)
-        text_encoder_.to(dtype=weight_dtype)
+            unet_.to(dtype=weight_dtype)
+            text_encoder_.to(dtype=weight_dtype)

-        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)

-        unet_.to(dtype=orig_unet_dtype)
-        text_encoder_.to(dtype=orig_text_encoder_dtype)
+            unet_.to(dtype=orig_unet_dtype)
+            text_encoder_.to(dtype=orig_text_encoder_dtype)

-        del unet_
-        del text_encoder_
+            del unet_
+            del text_encoder_

         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -112,6 +171,9 @@
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
-- 
cgit v1.2.3-70-g09d2
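
Note on the embedding decay introduced in training/strategy/ti.py: on_before_optimize reports which rows of temp_token_embedding.weight received no gradient in the current step, and on_after_optimize then pulls the norm of every updated row toward emb_decay_target with strength emb_decay * lr. The standalone Python sketch below shows the same idea in isolation; the helper name, the toy tensor sizes, and the explicit write-back of the masked rows are illustrative assumptions, not code from the patch.

import torch


def decay_embedding_norms(
    weight: torch.Tensor,      # (num_embeddings, embedding_dim), updated in place
    zero_ids: torch.Tensor,    # bool mask of rows whose grad was all-zero this step
    lr: float,
    emb_decay: float = 1e-2,
    emb_decay_target: float = 0.4,
) -> None:
    # Strength of the pull toward the target norm scales with the learning rate.
    lambda_ = emb_decay * lr
    if lambda_ == 0:
        return

    # Only decay rows that were actually updated by the optimizer this step.
    mask = torch.ones(weight.shape[0], dtype=torch.bool)
    mask[zero_ids] = False

    with torch.no_grad():
        rows = weight[mask]
        norm = rows.norm(dim=-1, keepdim=True)
        # Move each row along its own direction by lambda_ * (target - norm):
        # rows with a norm above the target shrink, rows below it grow.
        rows += (rows / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
        weight[mask] = rows


# Toy usage: 10 embeddings of dim 4; pretend rows 0 and 3 received no gradient.
emb = torch.nn.Embedding(10, 4)
zero_ids = torch.zeros(10, dtype=torch.bool)
zero_ids[[0, 3]] = True
decay_embedding_norms(emb.weight.data, zero_ids, lr=1e-2)
print(emb.weight.data.norm(dim=-1))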
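
The EMA handling follows one pattern throughout the strategy callbacks: ema_embeddings.step(...) after each optimizer step, and ema_context() to evaluate, checkpoint, and sample under the averaged weights. The class below is a minimal stand-in for that pattern, assuming a plain fixed-decay exponential moving average; the repo's training.util.EMAModel is the class actually used and additionally implements the inv_gamma/power decay schedule and max_value cap configured in the patch.

from contextlib import contextmanager

import torch


class SimpleEMA:
    """Minimal EMA over a list of parameters (illustrative stand-in)."""

    def __init__(self, parameters, decay: float = 0.999):
        self.decay = decay
        self.shadow = [p.detach().clone() for p in parameters]

    @torch.no_grad()
    def step(self, parameters):
        # shadow <- decay * shadow + (1 - decay) * param
        for s, p in zip(self.shadow, parameters):
            s.mul_(self.decay).add_(p, alpha=1 - self.decay)

    @contextmanager
    def apply_temporary(self, parameters):
        # Swap the averaged weights in (for eval/checkpoint/sample), then restore.
        parameters = list(parameters)
        backup = [p.detach().clone() for p in parameters]
        for p, s in zip(parameters, self.shadow):
            p.data.copy_(s)
        try:
            yield
        finally:
            for p, b in zip(parameters, backup):
                p.data.copy_(b)


# Usage mirroring the callbacks: update after each optimizer step,
# then save or sample under the averaged weights.
emb = torch.nn.Embedding(10, 4)
ema = SimpleEMA(emb.parameters(), decay=0.99)
ema.step(emb.parameters())
with ema.apply_temporary(emb.parameters()):
    pass  # run evaluation, save checkpoints, or generate samples here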