summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py4
-rw-r--r--training/strategy/lora.py8
-rw-r--r--training/strategy/ti.py2
3 files changed, 8 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py
index c6ceb20..695a24f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -456,7 +456,7 @@ def train_loop(
456 sample_frequency: int = 10, 456 sample_frequency: int = 10,
457 checkpoint_frequency: int = 50, 457 checkpoint_frequency: int = 50,
458 milestone_checkpoints: bool = True, 458 milestone_checkpoints: bool = True,
459 cycle: int = 1, 459 cycle: int = 0,
460 global_step_offset: int = 0, 460 global_step_offset: int = 0,
461 num_epochs: int = 100, 461 num_epochs: int = 100,
462 gradient_accumulation_steps: int = 1, 462 gradient_accumulation_steps: int = 1,
@@ -537,7 +537,7 @@ def train_loop(
537 537
538 logs = {} 538 logs = {}
539 539
540 with on_train(epoch): 540 with on_train(cycle):
541 for step, batch in enumerate(train_dataloader): 541 for step, batch in enumerate(train_dataloader):
542 loss, acc, bsz = loss_step(step, batch, cache) 542 loss, acc, bsz = loss_step(step, batch, cache)
543 loss /= gradient_accumulation_steps 543 loss /= gradient_accumulation_steps
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 5c3012e..1f0a117 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -35,6 +35,7 @@ def lora_strategy_callbacks(
35 placeholder_tokens: list[str], 35 placeholder_tokens: list[str],
36 placeholder_token_ids: list[list[int]], 36 placeholder_token_ids: list[list[int]],
37 pti_mode: bool = False, 37 pti_mode: bool = False,
38 train_text_encoder_cycles: int = 99999,
38 use_emb_decay: bool = False, 39 use_emb_decay: bool = False,
39 emb_decay_target: float = 0.4, 40 emb_decay_target: float = 0.4,
40 emb_decay: float = 1e-2, 41 emb_decay: float = 1e-2,
@@ -66,10 +67,11 @@ def lora_strategy_callbacks(
66 ) 67 )
67 68
68 @contextmanager 69 @contextmanager
69 def on_train(epoch: int): 70 def on_train(cycle: int):
70 unet.train() 71 unet.train()
71 text_encoder.train() 72 if cycle < train_text_encoder_cycles:
72 tokenizer.train() 73 text_encoder.train()
74 tokenizer.train()
73 yield 75 yield
74 76
75 @contextmanager 77 @contextmanager
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f330cb7..6bc1d7d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -90,7 +90,7 @@ def textual_inversion_strategy_callbacks(
90 return nullcontext() 90 return nullcontext()
91 91
92 @contextmanager 92 @contextmanager
93 def on_train(epoch: int): 93 def on_train(cycle: int):
94 text_encoder.train() 94 text_encoder.train()
95 tokenizer.train() 95 tokenizer.train()
96 yield 96 yield