| author | Volpeon <git@volpeon.ink> | 2023-04-07 11:02:47 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-07 11:02:47 +0200 |
| commit | f5b86b44565aaaa92543989a85ea5d88ca9b1c0c (patch) | |
| tree | df02bdcf757743708001fe70e9db2c3e2b9b4af9 | |
| parent | Update (diff) | |
Fix
| -rw-r--r-- | train_lora.py | 2 |
| -rw-r--r-- | training/strategy/lora.py | 7 |
| -rw-r--r-- | training/strategy/ti.py | 7 |
3 files changed, 9 insertions, 7 deletions
```diff
diff --git a/train_lora.py b/train_lora.py
index 0b26965..6de3a75 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -927,7 +927,7 @@ def main():
     # LORA
     # --------------------------------------------------------------------------------
 
-    lora_output_dir = output_dir / "pti"
+    lora_output_dir = output_dir / "lora"
     lora_checkpoint_output_dir = lora_output_dir / "model"
     lora_sample_output_dir = lora_output_dir / "samples"
 
```
```diff
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index d51a2f3..6730dc9 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -85,15 +85,16 @@ def lora_strategy_callbacks(
         )
 
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
```
```diff
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9df160a..55e9934 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -107,18 +107,19 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            return torch.stack([
+            params = [
                 p
                 for p in text_encoder.text_model.embeddings.token_override_embedding.params
                 if p.grad is not None
-            ])
+            ]
+            return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
     def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
-        if use_emb_decay:
+        if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
             if lambda_ != 0:
```
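Both strategy hunks apply the same guard: `on_before_optimize` now returns `None` instead of stacking an empty list when no token-override embedding received a gradient (presumably because `torch.stack([])` raises a `RuntimeError`), and `on_after_optimize` only applies the embedding decay when a tensor was actually returned. Below is a minimal, self-contained sketch of that pattern, not the repository's code: the captured state (`use_emb_decay`, `emb_decay`, `embedding_params`) is stubbed with placeholder values, and the decay update inside `if lambda_ != 0:` is a placeholder since the real update lies outside these hunks.

```python
import torch

# Placeholder stand-ins for the strategy's captured state (illustrative only).
use_emb_decay = True
emb_decay = 1e-2
embedding_params = [torch.zeros(768) for _ in range(4)]  # none of these has a gradient


@torch.no_grad()
def on_before_optimize(lr: float, epoch: int):
    if use_emb_decay:
        params = [p for p in embedding_params if p.grad is not None]
        # The guard from the diff: return None when no embedding received a
        # gradient this step, since torch.stack([]) would raise a RuntimeError.
        return torch.stack(params) if len(params) != 0 else None


@torch.no_grad()
def on_after_optimize(w, lr: float):
    # Second half of the fix: only decay when on_before_optimize returned a tensor.
    if use_emb_decay and w is not None:
        lambda_ = emb_decay * lr
        if lambda_ != 0:
            w.mul_(1 - lambda_)  # placeholder decay step; the real update is outside the hunk


w = on_before_optimize(lr=1e-4, epoch=0)  # returns None: no parameter had a gradient
on_after_optimize(w, lr=1e-4)             # safely a no-op instead of crashing
```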
