From 7da4f0485032bb8b8acfc678546ffcea3a23a44b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 21 Apr 2023 11:43:50 +0200
Subject: Update

---
 training/strategy/lora.py | 8 +++++---
 training/strategy/ti.py   | 2 +-
 2 files changed, 6 insertions(+), 4 deletions(-)

(limited to 'training/strategy')

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 5c3012e..1f0a117 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -35,6 +35,7 @@ def lora_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     pti_mode: bool = False,
+    train_text_encoder_cycles: int = 99999,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -66,10 +67,11 @@ def lora_strategy_callbacks(
     )
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
-        text_encoder.train()
-        tokenizer.train()
+        if cycle < train_text_encoder_cycles:
+            text_encoder.train()
+            tokenizer.train()
         yield
 
     @contextmanager
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f330cb7..6bc1d7d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -90,7 +90,7 @@ def textual_inversion_strategy_callbacks(
         return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         text_encoder.train()
         tokenizer.train()
         yield
-- 
cgit v1.2.3-70-g09d2
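
Note: a minimal, self-contained sketch of the cycle-gating pattern this patch introduces. The on_train callback now receives a cycle index, and the text encoder (and tokenizer) are only put into training mode while cycle < train_text_encoder_cycles; with the default of 99999, the old behavior (text encoder trains every cycle) is effectively preserved unless a caller opts in to an earlier freeze. The StubModule class and the driver loop below are illustrative assumptions, not part of the repository; only the gating logic inside on_train mirrors the patch.

    from contextlib import contextmanager

    class StubModule:
        # Stand-in for the unet / text_encoder / tokenizer objects in the
        # real strategy; only the train()/eval() mode toggle is modeled.
        def __init__(self, name: str):
            self.name = name
            self.training = False

        def train(self):
            self.training = True

        def eval(self):
            self.training = False

    unet = StubModule("unet")
    text_encoder = StubModule("text_encoder")
    tokenizer = StubModule("tokenizer")

    TRAIN_TEXT_ENCODER_CYCLES = 2  # e.g. freeze the text encoder after 2 cycles

    @contextmanager
    def on_train(cycle: int):
        # Mirrors the patched callback: the unet always trains, but the
        # text encoder and tokenizer only train during the first
        # TRAIN_TEXT_ENCODER_CYCLES cycles.
        unet.train()
        if cycle < TRAIN_TEXT_ENCODER_CYCLES:
            text_encoder.train()
            tokenizer.train()
        yield

    for cycle in range(4):
        for m in (unet, text_encoder, tokenizer):
            m.eval()  # a trainer would reset modes between phases
        with on_train(cycle):
            print(cycle, unet.training, text_encoder.training)
    # Cycles 0-1 print "True True"; cycles 2-3 print "True False".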