From 7da4f0485032bb8b8acfc678546ffcea3a23a44b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 21 Apr 2023 11:43:50 +0200
Subject: Update

---
 training/functional.py    | 4 ++--
 training/strategy/lora.py | 8 +++++---
 training/strategy/ti.py   | 2 +-
 3 files changed, 8 insertions(+), 6 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index c6ceb20..695a24f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -456,7 +456,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    cycle: int = 1,
+    cycle: int = 0,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -537,7 +537,7 @@ def train_loop(
 
         logs = {}
 
-        with on_train(epoch):
+        with on_train(cycle):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
                 loss /= gradient_accumulation_steps
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 5c3012e..1f0a117 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -35,6 +35,7 @@ def lora_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     pti_mode: bool = False,
+    train_text_encoder_cycles: int = 99999,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -66,10 +67,11 @@ def lora_strategy_callbacks(
     )
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
-        text_encoder.train()
-        tokenizer.train()
+        if cycle < train_text_encoder_cycles:
+            text_encoder.train()
+            tokenizer.train()
         yield
 
     @contextmanager
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f330cb7..6bc1d7d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -90,7 +90,7 @@ def textual_inversion_strategy_callbacks(
         return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         text_encoder.train()
         tokenizer.train()
         yield
--
cgit v1.2.3-54-g00ecf
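
For context, a minimal standalone sketch (not part of the patch) of the behavior the LoRA hunk introduces: `on_train` now receives the current training cycle instead of the epoch, and the text encoder and tokenizer are only switched to train mode while the cycle index is below `train_text_encoder_cycles`; the UNet keeps training either way. The `DummyModule` class and `make_on_train` factory below are placeholders for illustration only, not the repository's real objects.

# Illustrative sketch of the cycle-gated on_train callback; DummyModule stands in
# for the real UNet / text encoder / tokenizer objects used by the repository.
from contextlib import contextmanager


class DummyModule:
    def __init__(self, name: str):
        self.name = name
        self.training = False

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


def make_on_train(unet, text_encoder, tokenizer, train_text_encoder_cycles: int = 99999):
    @contextmanager
    def on_train(cycle: int):
        unet.train()
        # The text encoder (and tokenizer) only train during the first
        # `train_text_encoder_cycles` cycles; cycles now start at 0.
        if cycle < train_text_encoder_cycles:
            text_encoder.train()
            tokenizer.train()
        yield

    return on_train


if __name__ == "__main__":
    unet = DummyModule("unet")
    text_encoder = DummyModule("text_encoder")
    tokenizer = DummyModule("tokenizer")
    on_train = make_on_train(unet, text_encoder, tokenizer, train_text_encoder_cycles=2)

    for cycle in range(4):
        text_encoder.eval()
        tokenizer.eval()
        with on_train(cycle):
            # cycles 0 and 1 -> text encoder trains; cycles 2 and 3 -> it stays frozen
            print(cycle, unet.training, text_encoder.training)

With the default of 99999 the gate is effectively disabled, so existing callers keep their previous behavior unless they pass a smaller value.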