summaryrefslogtreecommitdiffstats
path: root/training/strategy/lora.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-21 11:43:50 +0200
committerVolpeon <git@volpeon.ink>2023-04-21 11:43:50 +0200
commit7da4f0485032bb8b8acfc678546ffcea3a23a44b (patch)
tree1e7880189df21132861114b5dbf4c614405c9855 /training/strategy/lora.py
parentFix PTI (diff)
downloadtextual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.tar.gz
textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.tar.bz2
textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.zip
Update
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r--training/strategy/lora.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 5c3012e..1f0a117 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -35,6 +35,7 @@ def lora_strategy_callbacks(
35 placeholder_tokens: list[str], 35 placeholder_tokens: list[str],
36 placeholder_token_ids: list[list[int]], 36 placeholder_token_ids: list[list[int]],
37 pti_mode: bool = False, 37 pti_mode: bool = False,
38 train_text_encoder_cycles: int = 99999,
38 use_emb_decay: bool = False, 39 use_emb_decay: bool = False,
39 emb_decay_target: float = 0.4, 40 emb_decay_target: float = 0.4,
40 emb_decay: float = 1e-2, 41 emb_decay: float = 1e-2,
@@ -66,10 +67,11 @@ def lora_strategy_callbacks(
66 ) 67 )
67 68
68 @contextmanager 69 @contextmanager
69 def on_train(epoch: int): 70 def on_train(cycle: int):
70 unet.train() 71 unet.train()
71 text_encoder.train() 72 if cycle < train_text_encoder_cycles:
72 tokenizer.train() 73 text_encoder.train()
74 tokenizer.train()
73 yield 75 yield
74 76
75 @contextmanager 77 @contextmanager