summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-04 07:30:43 +0200
committerVolpeon <git@volpeon.ink>2023-04-04 07:30:43 +0200
commit30b557c8e1f03b4748ac3efca599ff51d66561cb (patch)
tree59aaacde83a7a44dc267c64455f6dc2cfb90c01f /training
parentImproved sparse embeddings (diff)
downloadtextual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.tar.gz
textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.tar.bz2
textual-inversion-diff-30b557c8e1f03b4748ac3efca599ff51d66561cb.zip
TI: Bring back old embedding decay
Diffstat (limited to 'training')
-rw-r--r--training/functional.py4
-rw-r--r--training/strategy/ti.py22
2 files changed, 23 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index 1d8e2ee..96ecbc1 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols):
73 return grid 73 return grid
74 74
75 75
76def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): 76def get_models(pretrained_model_name_or_path: str):
77 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 77 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
78 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 78 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
82 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 82 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
83 pretrained_model_name_or_path, subfolder='scheduler') 83 pretrained_model_name_or_path, subfolder='scheduler')
84 84
85 embeddings = patch_managed_embeddings(text_encoder, emb_alpha) 85 embeddings = patch_managed_embeddings(text_encoder)
86 86
87 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings 87 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
88 88
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 95128da..9df160a 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,6 +31,9 @@ def textual_inversion_strategy_callbacks(
31 seed: int, 31 seed: int,
32 placeholder_tokens: list[str], 32 placeholder_tokens: list[str],
33 placeholder_token_ids: list[list[int]], 33 placeholder_token_ids: list[list[int]],
34 use_emb_decay: bool = False,
35 emb_decay_target: float = 0.4,
36 emb_decay: float = 1e-2,
34 use_ema: bool = False, 37 use_ema: bool = False,
35 ema_inv_gamma: float = 1.0, 38 ema_inv_gamma: float = 1.0,
36 ema_power: int = 1, 39 ema_power: int = 1,
@@ -102,10 +105,26 @@ def textual_inversion_strategy_callbacks(
102 yield 105 yield
103 106
104 @torch.no_grad() 107 @torch.no_grad()
105 def on_after_optimize(zero_ids, lr: float): 108 def on_before_optimize(lr: float, epoch: int):
109 if use_emb_decay:
110 return torch.stack([
111 p
112 for p in text_encoder.text_model.embeddings.token_override_embedding.params
113 if p.grad is not None
114 ])
115
116 @torch.no_grad()
117 def on_after_optimize(w, lr: float):
106 if ema_embeddings is not None: 118 if ema_embeddings is not None:
107 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) 119 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
108 120
121 if use_emb_decay:
122 lambda_ = emb_decay * lr
123
124 if lambda_ != 0:
125 norm = w[:, :].norm(dim=-1, keepdim=True)
126 w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
127
109 def on_log(): 128 def on_log():
110 if ema_embeddings is not None: 129 if ema_embeddings is not None:
111 return {"ema_decay": ema_embeddings.decay} 130 return {"ema_decay": ema_embeddings.decay}
@@ -149,6 +168,7 @@ def textual_inversion_strategy_callbacks(
149 on_accum_model=on_accum_model, 168 on_accum_model=on_accum_model,
150 on_train=on_train, 169 on_train=on_train,
151 on_eval=on_eval, 170 on_eval=on_eval,
171 on_before_optimize=on_before_optimize,
152 on_after_optimize=on_after_optimize, 172 on_after_optimize=on_after_optimize,
153 on_log=on_log, 173 on_log=on_log,
154 on_checkpoint=on_checkpoint, 174 on_checkpoint=on_checkpoint,