From 7da4f0485032bb8b8acfc678546ffcea3a23a44b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 21 Apr 2023 11:43:50 +0200
Subject: Update

---
 train_lora.py             | 6 ++++--
 training/functional.py    | 4 ++--
 training/strategy/lora.py | 8 +++++---
 training/strategy/ti.py   | 2 +-
 4 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 0d8b8cb..1d1485d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -873,7 +873,6 @@ def main():
         seed=args.seed,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
-        offset_noise_strength=args.offset_noise_strength,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
@@ -984,13 +983,14 @@ def main():
         lr_scheduler=pti_lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        cycle=1,
+        cycle=0,
         pti_mode=True,
         # --
         group_labels=["emb"],
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
+        offset_noise_strength=0,
         no_val=True,
     )
 
@@ -1132,11 +1132,13 @@ def main():
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         global_step_offset=training_iter * num_train_steps,
         cycle=training_iter,
+        train_text_encoder_cycles=args.train_text_encoder_cycles,
         # --
         group_labels=group_labels,
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
+        offset_noise_strength=args.offset_noise_strength,
         no_val=args.valid_set_size == 0,
     )
 
diff --git a/training/functional.py b/training/functional.py
index c6ceb20..695a24f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -456,7 +456,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    cycle: int = 1,
+    cycle: int = 0,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -537,7 +537,7 @@ def train_loop(
 
         logs = {}
 
-        with on_train(epoch):
+        with on_train(cycle):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
                 loss /= gradient_accumulation_steps
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 5c3012e..1f0a117 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -35,6 +35,7 @@ def lora_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     pti_mode: bool = False,
+    train_text_encoder_cycles: int = 99999,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -66,10 +67,11 @@ def lora_strategy_callbacks(
     )
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
-        text_encoder.train()
-        tokenizer.train()
+        if cycle < train_text_encoder_cycles:
+            text_encoder.train()
+            tokenizer.train()
         yield
 
     @contextmanager
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f330cb7..6bc1d7d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -90,7 +90,7 @@ def textual_inversion_strategy_callbacks(
             return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         text_encoder.train()
         tokenizer.train()
         yield
--
cgit v1.2.3-70-g09d2
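
For context, the behavioral core of this patch is the cycle-gated on_train callback in
training/strategy/lora.py: the text encoder (and tokenizer) now only enter train mode
while cycle < train_text_encoder_cycles, whereas the UNet trains every cycle. Renumbering
the first cycle from 1 to 0 keeps that comparison consistent with the 0-based
training_iter that train_lora.py passes as cycle. A minimal standalone sketch of the
pattern follows; the make_on_train factory name, the nn.Linear stand-ins, and the explicit
eval() fallback are illustrative additions, not code from this repo (the patch itself
simply skips the train() call, and tokenizer handling is omitted here):

    from contextlib import contextmanager

    import torch.nn as nn

    def make_on_train(unet: nn.Module, text_encoder: nn.Module,
                      train_text_encoder_cycles: int = 99999):
        """Build an on_train() context manager gated by the training cycle."""
        @contextmanager
        def on_train(cycle: int):
            unet.train()
            if cycle < train_text_encoder_cycles:
                text_encoder.train()  # still fine-tuning the text encoder
            else:
                text_encoder.eval()   # past the cutoff: leave it in eval mode
            yield
        return on_train

    unet, text_encoder = nn.Linear(4, 4), nn.Linear(4, 4)  # stand-ins for real models
    on_train = make_on_train(unet, text_encoder, train_text_encoder_cycles=2)
    with on_train(0):
        assert text_encoder.training       # cycles 0 and 1: both models train
    with on_train(2):
        assert not text_encoder.training   # cycle 2 onward: UNet only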
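
The offset_noise_strength relocation is the other functional change: instead of one value
shared by every phase, the PTI pass now runs with offset_noise_strength=0 while the LoRA
pass keeps args.offset_noise_strength, so embedding pre-training sees plain Gaussian
noise. For reference, offset noise is conventionally applied roughly as below; this is a
sketch of the common formulation with a hypothetical helper name, not necessarily this
repo's exact code:

    import torch

    def sample_training_noise(latents: torch.Tensor, strength: float) -> torch.Tensor:
        """Per-pixel Gaussian noise plus an optional per-channel DC offset."""
        noise = torch.randn_like(latents)
        if strength != 0:
            # One random scalar per (batch, channel), broadcast over H and W,
            # which lets the model learn to shift overall image brightness.
            noise = noise + strength * torch.randn(
                latents.shape[0], latents.shape[1], 1, 1,
                device=latents.device, dtype=latents.dtype,
            )
        return noise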