From e68cb3542e08c9f22ce8a94fd88bebe0c121ca17 Mon Sep 17 00:00:00 2001 From: Volpeon <git@volpeon.ink> Date: Mon, 3 Apr 2023 18:52:30 +0200 Subject: TI: Delta learning --- models/clip/embeddings.py | 50 +++++++++++++++++++++++++++++++---------------- train_ti.py | 37 +++++++++++------------------------ training/functional.py | 4 ++-- training/strategy/ti.py | 23 ---------------------- 4 files changed, 46 insertions(+), 68 deletions(-) diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 1e21965..d8343a0 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -12,7 +12,7 @@ from transformers.models.clip import CLIPTextConfig from transformers.models.clip.modeling_clip import CLIPTextEmbeddings -def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding: +def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding: old_num_embeddings, old_embedding_dim = old_embedding.weight.shape if old_num_embeddings == new_num_embeddings: @@ -26,13 +26,16 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi device=old_embedding.weight.device, dtype=old_embedding.weight.dtype ) - new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) + if initializer_factor is not None: + new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) + else: + nn.init.zeros_(new_embedding.weight.data) new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] return new_embedding class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): - def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4): + def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0): super().__init__(config) self.token_embedding = embeddings.token_embedding @@ -40,17 +43,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.initializer_factor = config.initializer_factor self.alpha = alpha - self.temp_token_embedding = nn.Embedding( - self.token_embedding.num_embeddings, - self.token_embedding.embedding_dim, - device=self.token_embedding.weight.device, - dtype=self.token_embedding.weight.dtype - ) - self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() + self.temp_token_embedding = nn.ParameterList() self.temp_token_ids = torch.tensor([], dtype=torch.long) def resize(self, size: int): - self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) + for _ in range(len(self.temp_token_embedding), size): + self.temp_token_embedding.append(torch.zeros( + self.token_embedding.embedding_dim, + device=self.token_embedding.weight.device, + dtype=self.token_embedding.weight.dtype, + )) self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) def add_embed( @@ -85,7 +87,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): token_ids = torch.tensor(token_ids, dtype=torch.long) self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) - self.temp_token_embedding.weight.data[token_ids] = initializer self.token_embedding.weight.data[token_ids] = initializer def load_embed(self, input_ids: list[int], filename: Path): @@ -96,16 +97,31 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): save_file({"embed": self.get_embed(input_ids)}, filename) def persist(self): - self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] + for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): + self.token_embedding.weight.data[id] += self.alpha * emb + nn.init.zeros_(emb) self.temp_token_ids = torch.tensor([], dtype=torch.long) def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): if isinstance(input_ids, list): input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) + all_temp_token_ids = self.temp_token_ids.to(input_ids.device) + embeds = self.token_embedding(input_ids) - mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) - embeds[mask] = self.temp_token_embedding(input_ids[mask]) + mask = torch.isin(input_ids, all_temp_token_ids) + temp_token_ids = input_ids[mask] + + temp_token_ids = temp_token_ids.unsqueeze(1) + all_temp_token_ids = all_temp_token_ids.unsqueeze(0) + temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() + + if len(temp_token_ids): + embeds_override = torch.stack([ + self.temp_token_embedding[id] + for id in temp_token_ids + ]) + embeds[mask] += self.alpha * embeds_override return embeds @@ -129,7 +145,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return embeddings -def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: - text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) +def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings: + text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha) text_encoder.text_model.embeddings = text_embeddings return text_embeddings diff --git a/train_ti.py b/train_ti.py index 8dde1ba..0ad7574 100644 --- a/train_ti.py +++ b/train_ti.py @@ -353,7 +353,7 @@ def parse_args(): parser.add_argument( "--adam_weight_decay", type=float, - default=0, + default=1e-2, help="Weight decay to use." ) parser.add_argument( @@ -451,21 +451,10 @@ def parse_args(): help="The weight of prior preservation loss." ) parser.add_argument( - "--use_emb_decay", - action="store_true", - help="Whether to use embedding decay." - ) - parser.add_argument( - "--emb_decay_target", - default=0.4, - type=float, - help="Embedding decay target." - ) - parser.add_argument( - "--emb_decay", - default=1e2, + "--emb_alpha", + default=1.0, type=float, - help="Embedding decay factor." + help="Embedding alpha." ) parser.add_argument( "--noise_timesteps", @@ -567,16 +556,16 @@ def parse_args(): raise ValueError("You must specify --output_dir") if args.adam_beta1 is None: - if args.optimizer in ('adam', 'adam8bit'): - args.adam_beta1 = 0.9 - elif args.optimizer == 'lion': + if args.optimizer == 'lion': args.adam_beta1 = 0.95 + else: + args.adam_beta1 = 0.9 if args.adam_beta2 is None: - if args.optimizer in ('adam', 'adam8bit'): - args.adam_beta2 = 0.999 - elif args.optimizer == 'lion': + if args.optimizer == 'lion': args.adam_beta2 = 0.98 + else: + args.adam_beta2 = 0.999 return args @@ -611,7 +600,7 @@ def main(): save_args(output_dir, args) tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( - args.pretrained_model_name_or_path) + args.pretrained_model_name_or_path, args.emb_alpha) tokenizer.set_use_vector_shuffle(args.vector_shuffle) tokenizer.set_dropout(args.vector_dropout) @@ -755,10 +744,6 @@ def main(): tokenizer=tokenizer, sample_scheduler=sample_scheduler, checkpoint_output_dir=checkpoint_output_dir, - gradient_checkpointing=args.gradient_checkpointing, - use_emb_decay=args.use_emb_decay, - emb_decay_target=args.emb_decay_target, - emb_decay=args.emb_decay, use_ema=args.use_ema, ema_inv_gamma=args.ema_inv_gamma, ema_power=args.ema_power, diff --git a/training/functional.py b/training/functional.py index 96ecbc1..1d8e2ee 100644 --- a/training/functional.py +++ b/training/functional.py @@ -73,7 +73,7 @@ def make_grid(images, rows, cols): return grid -def get_models(pretrained_model_name_or_path: str): +def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') @@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str): sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') - embeddings = patch_managed_embeddings(text_encoder) + embeddings = patch_managed_embeddings(text_encoder, emb_alpha) return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings diff --git a/training/strategy/ti.py b/training/strategy/ti.py index c7520ed..16baa34 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -31,10 +31,6 @@ def textual_inversion_strategy_callbacks( seed: int, placeholder_tokens: list[str], placeholder_token_ids: list[list[int]], - gradient_checkpointing: bool = False, - use_emb_decay: bool = False, - emb_decay_target: float = 0.4, - emb_decay: float = 1e-2, use_ema: bool = False, ema_inv_gamma: float = 1.0, ema_power: int = 1, @@ -105,29 +101,11 @@ def textual_inversion_strategy_callbacks( with ema_context(): yield - @torch.no_grad() - def on_before_optimize(lr: float, epoch: int): - if use_emb_decay: - w = text_encoder.text_model.embeddings.temp_token_embedding.weight - return torch.all(w.grad == 0, dim=1) - @torch.no_grad() def on_after_optimize(zero_ids, lr: float): if ema_embeddings is not None: ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - if use_emb_decay: - lambda_ = emb_decay * lr - - if lambda_ != 0: - w = text_encoder.text_model.embeddings.temp_token_embedding.weight - - mask = torch.ones(w.shape[0], dtype=torch.bool) - mask[zero_ids] = False - - norm = w[mask, :].norm(dim=-1, keepdim=True) - w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) - def on_log(): if ema_embeddings is not None: return {"ema_decay": ema_embeddings.decay} @@ -171,7 +149,6 @@ def textual_inversion_strategy_callbacks( on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, - on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, on_log=on_log, on_checkpoint=on_checkpoint, -- cgit v1.2.3-70-g09d2