From d488f66c78e444d03c4ef8a957b82f8b239379d0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Apr 2023 13:31:24 +0200
Subject: Fix

---
 models/clip/embeddings.py |  2 +-
 models/lora.py            |  8 ++++----
 training/strategy/ti.py   | 19 -------------------
 3 files changed, 5 insertions(+), 24 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 60c1b20..840f8ae 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -2,7 +2,6 @@ from typing import Union, Optional
 from pathlib import Path
 
 import torch
-import torch.nn as nn
 
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -64,6 +63,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
+        self.token_embedding.mark_trainable(token_ids)
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
diff --git a/models/lora.py b/models/lora.py
index c0f74a6..98d4d2c 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -83,11 +83,11 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return
 
-        n = self.trainable_ids.shape[0]
-        self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0])
+        n1 = self.lora_A.shape[1]
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)
 
-        lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0)))
-        lora_A.data[:n] = self.lora_A.data
+        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
         self.lora_A = lora_A
 
     def reset_parameters(self):
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 49236c6..f0b84b5 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -103,29 +103,11 @@ def textual_inversion_strategy_callbacks(
         with ema_context():
             yield
 
-    @torch.no_grad()
-    def on_before_optimize(epoch: int):
-        if use_emb_decay:
-            params = [
-                p
-                for p in text_encoder.text_model.embeddings.token_embedding.parameters()
-                if p.grad is not None
-            ]
-            return torch.stack(params) if len(params) != 0 else None
-
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
 
-        if use_emb_decay and w is not None:
-            lr = lrs["emb"] if "emb" in lrs else lrs["0"]
-            lambda_ = emb_decay * lr
-
-            if lambda_ != 0:
-                norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
-
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -166,7 +148,6 @@ def textual_inversion_strategy_callbacks(
     return TrainingCallbacks(
         on_train=on_train,
         on_eval=on_eval,
-        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
-- 
cgit v1.2.3-70-g09d2
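
Note on the models/lora.py hunk: the previous code derived the next free LoRA column index from trainable_ids.shape[0] (the full vocabulary size) and allocated lora_A with a zero-width second dimension, while the fix derives the index from lora_A's current width and allocates lora_A as (r, n2). The following minimal, self-contained sketch illustrates that bookkeeping only; the class name TinyLoraEmbedding, its constructor signature, and the lora_B parameter are illustrative assumptions, not the repository's actual LoraEmbedding.

import torch
import torch.nn as nn


class TinyLoraEmbedding(nn.Embedding):
    """Illustrative stand-in for the patched LoRA embedding (names are assumptions)."""

    def __init__(self, num_embeddings: int, embedding_dim: int, r: int = 4):
        super().__init__(num_embeddings, embedding_dim)
        self.r = r
        # -1 marks "not trainable"; any other value is a column index into lora_A.
        self.register_buffer(
            "trainable_ids", torch.full((num_embeddings,), -1, dtype=torch.long)
        )
        self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
        self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))

    def mark_trainable(self, input_ids: torch.Tensor):
        # Keep only token ids that have no LoRA column assigned yet.
        new_ids = input_ids[self.trainable_ids[input_ids] == -1]
        if new_ids.shape[0] == 0:
            return

        # The next free column index is the current width of lora_A,
        # not the length of trainable_ids (which is the full vocab size).
        n1 = self.lora_A.shape[1]
        n2 = n1 + new_ids.shape[0]
        self.trainable_ids[new_ids] = torch.arange(n1, n2)

        # Grow lora_A to one column per trainable token; like the patched code,
        # this sketch reinitializes it to zeros rather than copying old columns.
        self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))


emb = TinyLoraEmbedding(num_embeddings=49410, embedding_dim=768, r=4)
emb.mark_trainable(torch.tensor([49408, 49409], dtype=torch.long))
print(emb.lora_A.shape)           # torch.Size([4, 2])
print(emb.trainable_ids[49408:])  # tensor([0, 1])

This also matches the embeddings.py hunk, where mark_trainable is called before the new token's weight row is overwritten with its initializer.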