| author    | Volpeon <git@volpeon.ink> | 2023-04-01 16:30:36 +0200 |
|-----------|---------------------------|---------------------------|
| committer | Volpeon <git@volpeon.ink> | 2023-04-01 16:30:36 +0200 |
| commit    | c96073646bbb638d7d78fdd7d9fdeed08d1454b5 (patch) | |
| tree      | 3e0846964fa127844d652e2dee081cd67e672e6a /training/strategy | |
| parent    | Update (diff) | |
| download  | textual-inversion-diff-c96073646bbb638d7d78fdd7d9fdeed08d1454b5 (.tar.gz, .tar.bz2, .zip) | |
Experimental: TI via LoRA
Diffstat (limited to 'training/strategy')

    training/strategy/ti.py | 30 ++++--------------------------

1 file changed, 4 insertions(+), 26 deletions(-)
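In short: this commit swaps the dedicated `temp_token_embedding` matrix for a LoRA-style `overlay` module on the text encoder's embeddings, and drops the embedding-norm decay path (`use_emb_decay`, `emb_decay_target`, `emb_decay`, and the `on_before_optimize` callback that masked rows with all-zero gradients) — presumably because with a low-rank overlay there is no longer a per-token embedding matrix whose norms need decaying.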
```diff
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..19b8d25 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,9 +32,6 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
-    use_emb_decay: bool = False,
-    emb_decay_target: float = 0.4,
-    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -73,7 +70,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -85,13 +82,13 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+                text_encoder.text_model.embeddings.overlay.parameters()
             )
         else:
             return nullcontext()
 
     def on_accum_model():
-        return text_encoder.text_model.embeddings.temp_token_embedding
+        return text_encoder.text_model.embeddings.overlay
 
     @contextmanager
     def on_train(epoch: int):
@@ -106,27 +103,9 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
-        if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
-
-    @torch.no_grad()
     def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-        if use_emb_decay:
-            lambda_ = emb_decay * lr
-
-            if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
 
     def on_log():
         if ema_embeddings is not None:
@@ -171,7 +150,6 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
-        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
```
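The diff references `text_encoder.text_model.embeddings.overlay` but its definition lives outside this commit. For orientation, here is a minimal sketch of what a LoRA-style embedding overlay could look like, assuming a frozen base token embedding plus a trainable low-rank delta; the class name `EmbeddingOverlay` and the `rank`/`alpha` parameters are hypothetical, not the repository's actual API:

```python
import torch
import torch.nn as nn


class EmbeddingOverlay(nn.Module):
    """Hypothetical sketch: a frozen base token embedding plus a trainable
    low-rank (LoRA-style) delta, so textual inversion trains only the two
    small factor matrices instead of a full embedding copy."""

    def __init__(self, base: nn.Embedding, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # base embedding stays frozen
        self.scale = alpha / rank
        # delta = (lora_A @ lora_B) * scale, same shape as base.weight;
        # one factor starts at zero so the overlay initially changes nothing
        self.lora_A = nn.Parameter(torch.zeros(base.num_embeddings, rank))
        self.lora_B = nn.Parameter(torch.randn(rank, base.embedding_dim) * 1e-2)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        delta = (self.lora_A @ self.lora_B) * self.scale
        return self.base(input_ids) + delta[input_ids]
```

Under that assumption, `embeddings.overlay.parameters()` yields only the two small factor matrices, which would explain why the EMA model, `ema_context`, and `on_accum_model` above now track the overlay rather than a full `temp_token_embedding` matrix.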