From 7da4f0485032bb8b8acfc678546ffcea3a23a44b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 21 Apr 2023 11:43:50 +0200 Subject: Update --- training/strategy/lora.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'training/strategy/lora.py') diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 5c3012e..1f0a117 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -35,6 +35,7 @@ def lora_strategy_callbacks( placeholder_tokens: list[str], placeholder_token_ids: list[list[int]], pti_mode: bool = False, + train_text_encoder_cycles: int = 99999, use_emb_decay: bool = False, emb_decay_target: float = 0.4, emb_decay: float = 1e-2, @@ -66,10 +67,11 @@ def lora_strategy_callbacks( ) @contextmanager - def on_train(epoch: int): + def on_train(cycle: int): unet.train() - text_encoder.train() - tokenizer.train() + if cycle < train_text_encoder_cycles: + text_encoder.train() + tokenizer.train() yield @contextmanager -- cgit v1.2.3-54-g00ecf