summary | refs | log | tree | commit | diff | stats
path: root/training/strategy
diff options
context:
space:
mode:
author	Volpeon <git@volpeon.ink>	2023-04-21 11:43:50 +0200
committer	Volpeon <git@volpeon.ink>	2023-04-21 11:43:50 +0200
commit	7da4f0485032bb8b8acfc678546ffcea3a23a44b (patch)
tree	1e7880189df21132861114b5dbf4c614405c9855 /training/strategy
parent	Fix PTI (diff)
download	textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.tar.gz
	textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.tar.bz2
	textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.zip
Update
Diffstat (limited to 'training/strategy')
-rw-r--r--	training/strategy/lora.py	8
-rw-r--r--	training/strategy/ti.py	2
2 files changed, 6 insertions, 4 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 5c3012e..1f0a117 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -35,6 +35,7 @@ def lora_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     pti_mode: bool = False,
+    train_text_encoder_cycles: int = 99999,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -66,10 +67,11 @@ def lora_strategy_callbacks(
     )

     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
-        text_encoder.train()
-        tokenizer.train()
+        if cycle < train_text_encoder_cycles:
+            text_encoder.train()
+            tokenizer.train()
         yield

     @contextmanager
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f330cb7..6bc1d7d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -90,7 +90,7 @@ def textual_inversion_strategy_callbacks(
         return nullcontext()

     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         text_encoder.train()
         tokenizer.train()
         yield