diff options
author | Volpeon <git@volpeon.ink> | 2023-04-08 17:38:49 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-08 17:38:49 +0200 |
commit | 9f5f70cb2a8919cb07821f264bf0fd75bfa10584 (patch) | |
tree | 19bd8802b6cfd941797beabfc0bb2595ffb00b5f /training/strategy | |
parent | Fix TI (diff) | |
download | textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.tar.gz textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.tar.bz2 textual-inversion-diff-9f5f70cb2a8919cb07821f264bf0fd75bfa10584.zip |
Update
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/lora.py | 2 | ||||
-rw-r--r-- | training/strategy/ti.py | 12 |
2 files changed, 7 insertions, 7 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index cfdc504..ae85401 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -93,7 +93,7 @@ def lora_strategy_callbacks( | |||
93 | if use_emb_decay: | 93 | if use_emb_decay: |
94 | params = [ | 94 | params = [ |
95 | p | 95 | p |
96 | for p in text_encoder.text_model.embeddings.token_override_embedding.params | 96 | for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() |
97 | if p.grad is not None | 97 | if p.grad is not None |
98 | ] | 98 | ] |
99 | return torch.stack(params) if len(params) != 0 else None | 99 | return torch.stack(params) if len(params) != 0 else None |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 720ebf3..289d6bd 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks( | |||
72 | 72 | ||
73 | if use_ema: | 73 | if use_ema: |
74 | ema_embeddings = EMAModel( | 74 | ema_embeddings = EMAModel( |
75 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | 75 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), |
76 | inv_gamma=ema_inv_gamma, | 76 | inv_gamma=ema_inv_gamma, |
77 | power=ema_power, | 77 | power=ema_power, |
78 | max_value=ema_max_decay, | 78 | max_value=ema_max_decay, |
@@ -84,20 +84,20 @@ def textual_inversion_strategy_callbacks( | |||
84 | def ema_context(): | 84 | def ema_context(): |
85 | if ema_embeddings is not None: | 85 | if ema_embeddings is not None: |
86 | return ema_embeddings.apply_temporary( | 86 | return ema_embeddings.apply_temporary( |
87 | text_encoder.text_model.embeddings.token_override_embedding.params.parameters() | 87 | text_encoder.text_model.embeddings.token_override_embedding.parameters() |
88 | ) | 88 | ) |
89 | else: | 89 | else: |
90 | return nullcontext() | 90 | return nullcontext() |
91 | 91 | ||
92 | @contextmanager | 92 | @contextmanager |
93 | def on_train(epoch: int): | 93 | def on_train(epoch: int): |
94 | text_encoder.text_model.embeddings.token_override_embedding.params.train() | 94 | text_encoder.train() |
95 | tokenizer.train() | 95 | tokenizer.train() |
96 | yield | 96 | yield |
97 | 97 | ||
98 | @contextmanager | 98 | @contextmanager |
99 | def on_eval(): | 99 | def on_eval(): |
100 | text_encoder.text_model.embeddings.token_override_embedding.params.eval() | 100 | text_encoder.eval() |
101 | tokenizer.eval() | 101 | tokenizer.eval() |
102 | 102 | ||
103 | with ema_context(): | 103 | with ema_context(): |
@@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks( | |||
108 | if use_emb_decay: | 108 | if use_emb_decay: |
109 | params = [ | 109 | params = [ |
110 | p | 110 | p |
111 | for p in text_encoder.text_model.embeddings.token_override_embedding.params | 111 | for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() |
112 | if p.grad is not None | 112 | if p.grad is not None |
113 | ] | 113 | ] |
114 | return torch.stack(params) if len(params) != 0 else None | 114 | return torch.stack(params) if len(params) != 0 else None |
@@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks( | |||
116 | @torch.no_grad() | 116 | @torch.no_grad() |
117 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
118 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters()) |
120 | 120 | ||
121 | if use_emb_decay and w is not None: | 121 | if use_emb_decay and w is not None: |
122 | lr = lrs["emb"] or lrs["0"] | 122 | lr = lrs["emb"] or lrs["0"] |