diff options
author | Volpeon <git@volpeon.ink> | 2023-04-21 11:43:50 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-21 11:43:50 +0200 |
commit | 7da4f0485032bb8b8acfc678546ffcea3a23a44b (patch) | |
tree | 1e7880189df21132861114b5dbf4c614405c9855 /training/strategy/lora.py | |
parent | Fix PTI (diff) | |
download | textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.tar.gz textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.tar.bz2 textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.zip |
Update
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r-- | training/strategy/lora.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 5c3012e..1f0a117 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -35,6 +35,7 @@ def lora_strategy_callbacks( | |||
35 | placeholder_tokens: list[str], | 35 | placeholder_tokens: list[str], |
36 | placeholder_token_ids: list[list[int]], | 36 | placeholder_token_ids: list[list[int]], |
37 | pti_mode: bool = False, | 37 | pti_mode: bool = False, |
38 | train_text_encoder_cycles: int = 99999, | ||
38 | use_emb_decay: bool = False, | 39 | use_emb_decay: bool = False, |
39 | emb_decay_target: float = 0.4, | 40 | emb_decay_target: float = 0.4, |
40 | emb_decay: float = 1e-2, | 41 | emb_decay: float = 1e-2, |
@@ -66,10 +67,11 @@ def lora_strategy_callbacks( | |||
66 | ) | 67 | ) |
67 | 68 | ||
68 | @contextmanager | 69 | @contextmanager |
69 | def on_train(epoch: int): | 70 | def on_train(cycle: int): |
70 | unet.train() | 71 | unet.train() |
71 | text_encoder.train() | 72 | if cycle < train_text_encoder_cycles: |
72 | tokenizer.train() | 73 | text_encoder.train() |
74 | tokenizer.train() | ||
73 | yield | 75 | yield |
74 | 76 | ||
75 | @contextmanager | 77 | @contextmanager |