diff options
Diffstat (limited to 'training')
-rw-r--r-- training/functional.py    | 4 ++--
-rw-r--r-- training/strategy/lora.py | 8 +++++---
-rw-r--r-- training/strategy/ti.py   | 2 +-
3 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index c6ceb20..695a24f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -456,7 +456,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    cycle: int = 1,
+    cycle: int = 0,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -537,7 +537,7 @@ def train_loop(

         logs = {}

-        with on_train(epoch):
+        with on_train(cycle):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
                 loss /= gradient_accumulation_steps
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 5c3012e..1f0a117 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -35,6 +35,7 @@ def lora_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     pti_mode: bool = False,
+    train_text_encoder_cycles: int = 99999,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -66,10 +67,11 @@ def lora_strategy_callbacks( | |||
66 | ) | 67 | ) |
67 | 68 | ||
68 | @contextmanager | 69 | @contextmanager |
69 | def on_train(epoch: int): | 70 | def on_train(cycle: int): |
70 | unet.train() | 71 | unet.train() |
71 | text_encoder.train() | 72 | if cycle < train_text_encoder_cycles: |
72 | tokenizer.train() | 73 | text_encoder.train() |
74 | tokenizer.train() | ||
73 | yield | 75 | yield |
74 | 76 | ||
75 | @contextmanager | 77 | @contextmanager |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f330cb7..6bc1d7d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -90,7 +90,7 @@ def textual_inversion_strategy_callbacks(
         return nullcontext()

     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         text_encoder.train()
         tokenizer.train()
         yield