summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py4
-rw-r--r--training/strategy/ti.py23
2 files changed, 2 insertions, 25 deletions
diff --git a/training/functional.py b/training/functional.py
index 96ecbc1..1d8e2ee 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols):
73 return grid 73 return grid
74 74
75 75
76def get_models(pretrained_model_name_or_path: str): 76def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
77 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 77 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
78 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 78 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str):
82 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 82 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
83 pretrained_model_name_or_path, subfolder='scheduler') 83 pretrained_model_name_or_path, subfolder='scheduler')
84 84
85 embeddings = patch_managed_embeddings(text_encoder) 85 embeddings = patch_managed_embeddings(text_encoder, emb_alpha)
86 86
87 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings 87 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
88 88
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index c7520ed..16baa34 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,10 +31,6 @@ def textual_inversion_strategy_callbacks(
31 seed: int, 31 seed: int,
32 placeholder_tokens: list[str], 32 placeholder_tokens: list[str],
33 placeholder_token_ids: list[list[int]], 33 placeholder_token_ids: list[list[int]],
34 gradient_checkpointing: bool = False,
35 use_emb_decay: bool = False,
36 emb_decay_target: float = 0.4,
37 emb_decay: float = 1e-2,
38 use_ema: bool = False, 34 use_ema: bool = False,
39 ema_inv_gamma: float = 1.0, 35 ema_inv_gamma: float = 1.0,
40 ema_power: int = 1, 36 ema_power: int = 1,
@@ -106,28 +102,10 @@ def textual_inversion_strategy_callbacks(
106 yield 102 yield
107 103
108 @torch.no_grad() 104 @torch.no_grad()
109 def on_before_optimize(lr: float, epoch: int):
110 if use_emb_decay:
111 w = text_encoder.text_model.embeddings.temp_token_embedding.weight
112 return torch.all(w.grad == 0, dim=1)
113
114 @torch.no_grad()
115 def on_after_optimize(zero_ids, lr: float): 105 def on_after_optimize(zero_ids, lr: float):
116 if ema_embeddings is not None: 106 if ema_embeddings is not None:
117 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) 107 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
118 108
119 if use_emb_decay:
120 lambda_ = emb_decay * lr
121
122 if lambda_ != 0:
123 w = text_encoder.text_model.embeddings.temp_token_embedding.weight
124
125 mask = torch.ones(w.shape[0], dtype=torch.bool)
126 mask[zero_ids] = False
127
128 norm = w[mask, :].norm(dim=-1, keepdim=True)
129 w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
130
131 def on_log(): 109 def on_log():
132 if ema_embeddings is not None: 110 if ema_embeddings is not None:
133 return {"ema_decay": ema_embeddings.decay} 111 return {"ema_decay": ema_embeddings.decay}
@@ -171,7 +149,6 @@ def textual_inversion_strategy_callbacks(
171 on_accum_model=on_accum_model, 149 on_accum_model=on_accum_model,
172 on_train=on_train, 150 on_train=on_train,
173 on_eval=on_eval, 151 on_eval=on_eval,
174 on_before_optimize=on_before_optimize,
175 on_after_optimize=on_after_optimize, 152 on_after_optimize=on_after_optimize,
176 on_log=on_log, 153 on_log=on_log,
177 on_checkpoint=on_checkpoint, 154 on_checkpoint=on_checkpoint,