From 86e908656bcd7585ec45cd930176800f759f146a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 1 Apr 2023 17:33:00 +0200 Subject: Combined TI with embedding and LoRA --- models/clip/embeddings.py | 19 +++++++++--- train_ti.py | 30 ++++--------------- training/strategy/ti.py | 76 +++++++++++------------------------------------ 3 files changed, 38 insertions(+), 87 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 88e0cc0..c9c788c 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -66,12 +66,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.initializer_factor = config.initializer_factor self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank) + self.temp_token_embedding = nn.Embedding( + self.token_embedding.num_embeddings, + self.token_embedding.embedding_dim, + device=self.token_embedding.weight.device, + dtype=self.token_embedding.weight.dtype + ) + self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() self.temp_token_ids = torch.tensor([], dtype=torch.long) def reset_overlay(self): self.overlay.reset() def resize(self, size: int): + self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) def add_embed( @@ -106,6 +114,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) + self.temp_token_embedding.weight.data[token_ids] = initializer self.token_embedding.weight.data[token_ids] = initializer def load_embed(self, input_ids: list[int], filename: Path): @@ -116,9 +125,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - self.token_embedding.weight.data[self.temp_token_ids] += self.overlay( - self.token_embedding.weight.data[self.temp_token_ids] - ) + embeds = self.temp_token_embedding.weight.data[self.temp_token_ids] + self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds) self.overlay.reset() self.temp_token_ids = torch.tensor([], dtype=torch.long) @@ -127,8 +135,11 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) embeds = self.token_embedding(input_ids) + mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) - embeds[mask] += self.overlay(embeds[mask]) + + temp_embeds = self.temp_token_embedding(input_ids[mask]) + embeds[mask] = temp_embeds + self.overlay(temp_embeds) return embeds diff --git a/train_ti.py b/train_ti.py index 0ce0056..26ac384 100644 --- a/train_ti.py +++ b/train_ti.py @@ -1,6 +1,7 @@ import argparse import datetime import logging +import itertools from functools import partial from pathlib import Path import math @@ -306,26 +307,6 @@ def parse_args(): default=0.04, help="Minimum learning rate in the lr scheduler." ) - parser.add_argument( - "--use_ema", - action="store_true", - help="Whether to use EMA model." - ) - parser.add_argument( - "--ema_inv_gamma", - type=float, - default=1.0 - ) - parser.add_argument( - "--ema_power", - type=float, - default=4/5 - ) - parser.add_argument( - "--ema_max_decay", - type=float, - default=0.9999 - ) parser.add_argument( "--optimizer", type=str, @@ -715,10 +696,6 @@ def main(): sample_scheduler=sample_scheduler, checkpoint_output_dir=checkpoint_output_dir, gradient_checkpointing=args.gradient_checkpointing, - use_ema=args.use_ema, - ema_inv_gamma=args.ema_inv_gamma, - ema_power=args.ema_power, - ema_max_decay=args.ema_max_decay, sample_batch_size=args.sample_batch_size, sample_num_batches=args.sample_batches, sample_num_steps=args.sample_steps, @@ -780,7 +757,10 @@ def main(): sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) optimizer = create_optimizer( - text_encoder.text_model.embeddings.overlay.parameters(), + itertools.chain( + text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + text_encoder.text_model.embeddings.overlay.parameters(), + ), lr=args.learning_rate, ) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 19b8d25..33f5fb9 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -1,6 +1,6 @@ from typing import Optional from functools import partial -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager from pathlib import Path import torch @@ -13,7 +13,6 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch from slugify import slugify from models.clip.tokenizer import MultiCLIPTokenizer -from training.util import EMAModel from training.functional import TrainingStrategy, TrainingCallbacks, save_samples @@ -32,10 +31,6 @@ def textual_inversion_strategy_callbacks( placeholder_tokens: list[str], placeholder_token_ids: list[list[int]], gradient_checkpointing: bool = False, - use_ema: bool = False, - ema_inv_gamma: float = 1.0, - ema_power: int = 1, - ema_max_decay: float = 0.9999, sample_batch_size: int = 1, sample_num_batches: int = 1, sample_num_steps: int = 20, @@ -68,25 +63,6 @@ def textual_inversion_strategy_callbacks( image_size=sample_image_size, ) - if use_ema: - ema_embeddings = EMAModel( - text_encoder.text_model.embeddings.overlay.parameters(), - inv_gamma=ema_inv_gamma, - power=ema_power, - max_value=ema_max_decay, - ) - ema_embeddings.to(accelerator.device) - else: - ema_embeddings = None - - def ema_context(): - if ema_embeddings is not None: - return ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.overlay.parameters() - ) - else: - return nullcontext() - def on_accum_model(): return text_encoder.text_model.embeddings.overlay @@ -98,50 +74,36 @@ def textual_inversion_strategy_callbacks( @contextmanager def on_eval(): tokenizer.eval() - - with ema_context(): - yield - - @torch.no_grad() - def on_after_optimize(zero_ids, lr: float): - if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters()) - - def on_log(): - if ema_embeddings is not None: - return {"ema_decay": ema_embeddings.decay} - return {} + yield @torch.no_grad() def on_checkpoint(step, postfix): print(f"Saving checkpoint for step {step}...") - with ema_context(): - for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): - text_encoder.text_model.embeddings.save_embed( - ids, - checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" - ) + for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): + text_encoder.text_model.embeddings.save_embed( + ids, + checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" + ) @torch.no_grad() def on_sample(step): - with ema_context(): - unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) - text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) + unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) - orig_unet_dtype = unet_.dtype - orig_text_encoder_dtype = text_encoder_.dtype + orig_unet_dtype = unet_.dtype + orig_text_encoder_dtype = text_encoder_.dtype - unet_.to(dtype=weight_dtype) - text_encoder_.to(dtype=weight_dtype) + unet_.to(dtype=weight_dtype) + text_encoder_.to(dtype=weight_dtype) - save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) - unet_.to(dtype=orig_unet_dtype) - text_encoder_.to(dtype=orig_text_encoder_dtype) + unet_.to(dtype=orig_unet_dtype) + text_encoder_.to(dtype=orig_text_encoder_dtype) - del unet_ - del text_encoder_ + del unet_ + del text_encoder_ if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -150,8 +112,6 @@ def textual_inversion_strategy_callbacks( on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, - on_after_optimize=on_after_optimize, - on_log=on_log, on_checkpoint=on_checkpoint, on_sample=on_sample, ) -- cgit v1.2.3-70-g09d2