From 71f4a40bb48be4f2759ba2d83faff39691cb2955 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 19:03:25 +0200 Subject: Improved automation caps --- train_lora.py | 53 +++++++++++++++++++++++++---------------- train_ti.py | 53 +++++++++++++++++++++++++---------------- training/functional.py | 17 ++++++------- training/strategy/dreambooth.py | 4 ++-- training/strategy/lora.py | 4 ++-- training/strategy/ti.py | 23 ++++++++++++++++-- 6 files changed, 100 insertions(+), 54 deletions(-) diff --git a/train_lora.py b/train_lora.py index 4d4c16a..ba5aee1 100644 --- a/train_lora.py +++ b/train_lora.py @@ -84,9 +84,9 @@ def parse_args(): ) parser.add_argument( "--auto_cycles", - type=int, - default=1, - help="How many cycles to run automatically." + type=str, + default="o", + help="Cycles to run automatically." ) parser.add_argument( "--cycle_decay", @@ -94,11 +94,6 @@ def parse_args(): default=1.0, help="Learning rate decay per cycle." ) - parser.add_argument( - "--cycle_constant", - action="store_true", - help="Use constant LR on cycles > 1." - ) parser.add_argument( "--placeholder_tokens", type=str, @@ -920,7 +915,6 @@ def main(): annealing_func=args.lr_annealing_func, warmup_exp=args.lr_warmup_exp, annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, end_lr=1e2, mid_point=args.lr_mid_point, ) @@ -964,20 +958,38 @@ def main(): lora_sample_output_dir = output_dir / lora_project / "samples" + auto_cycles = list(args.auto_cycles) + lr_scheduler = args.lr_scheduler + lr_warmup_epochs = args.lr_warmup_epochs + lr_cycles = args.lr_cycles + while True: - if training_iter >= args.auto_cycles: - response = input("Run another cycle? [y/n] ") - if response.lower().strip() == "n": - break + if len(auto_cycles) != 0: + response = auto_cycles.pop(0) + else: + response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") + + if response.lower().strip() == "o": + lr_scheduler = "one_cycle" + lr_warmup_epochs = args.lr_warmup_epochs + lr_cycles = args.lr_cycles + if response.lower().strip() == "w": + lr_scheduler = "constant" + lr_warmup_epochs = num_train_epochs + if response.lower().strip() == "c": + lr_scheduler = "constant" + lr_warmup_epochs = 0 + if response.lower().strip() == "d": + lr_scheduler = "cosine" + lr_warmup_epochs = 0 + lr_cycles = 1 + elif response.lower().strip() == "s": + break print("") print(f"============ LoRA cycle {training_iter + 1} ============") print("") - if args.cycle_constant and training_iter == 1: - args.lr_scheduler = "constant" - args.lr_warmup_epochs = 0 - params_to_optimize = [] if len(args.placeholder_tokens) != 0: @@ -1012,12 +1024,13 @@ def main(): lora_optimizer = create_optimizer(params_to_optimize) lora_lr_scheduler = create_lr_scheduler( - args.lr_scheduler, + lr_scheduler, gradient_accumulation_steps=args.gradient_accumulation_steps, optimizer=lora_optimizer, num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), train_epochs=num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, + cycles=lr_cycles, + warmup_epochs=lr_warmup_epochs, ) lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}" @@ -1031,7 +1044,7 @@ def main(): num_train_epochs=num_train_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, global_step_offset=training_iter * num_train_steps, - initial_samples=training_iter == 0, + cycle=training_iter, # -- group_labels=group_labels, sample_output_dir=lora_sample_output_dir, diff --git a/train_ti.py b/train_ti.py index c452269..880320f 100644 --- a/train_ti.py +++ b/train_ti.py @@ -68,9 +68,9 @@ def parse_args(): ) parser.add_argument( "--auto_cycles", - type=int, - default=1, - help="How many cycles to run automatically." + type=str, + default="o", + help="Cycles to run automatically." ) parser.add_argument( "--cycle_decay", @@ -78,11 +78,6 @@ def parse_args(): default=1.0, help="Learning rate decay per cycle." ) - parser.add_argument( - "--cycle_constant", - action="store_true", - help="Use constant LR on cycles > 1." - ) parser.add_argument( "--placeholder_tokens", type=str, @@ -921,27 +916,45 @@ def main(): sample_output_dir = output_dir / project / "samples" + auto_cycles = list(args.auto_cycles) + lr_scheduler = args.lr_scheduler + lr_warmup_epochs = args.lr_warmup_epochs + lr_cycles = args.lr_cycles + while True: - if training_iter >= args.auto_cycles: - response = input("Run another cycle? [y/n] ") - if response.lower().strip() == "n": - break + if len(auto_cycles) != 0: + response = auto_cycles.pop(0) + else: + response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") + + if response.lower().strip() == "o": + lr_scheduler = "one_cycle" + lr_warmup_epochs = args.lr_warmup_epochs + lr_cycles = args.lr_cycles + if response.lower().strip() == "w": + lr_scheduler = "constant" + lr_warmup_epochs = num_train_epochs + if response.lower().strip() == "c": + lr_scheduler = "constant" + lr_warmup_epochs = 0 + if response.lower().strip() == "d": + lr_scheduler = "cosine" + lr_warmup_epochs = 0 + lr_cycles = 1 + elif response.lower().strip() == "s": + break print("") print(f"------------ TI cycle {training_iter + 1} ------------") print("") - if args.cycle_constant and training_iter == 1: - args.lr_scheduler = "constant" - args.lr_warmup_epochs = 0 - optimizer = create_optimizer( text_encoder.text_model.embeddings.token_embedding.parameters(), lr=learning_rate, ) lr_scheduler = get_scheduler( - args.lr_scheduler, + lr_scheduler, optimizer=optimizer, num_training_steps_per_epoch=len(datamodule.train_dataloader), gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -950,10 +963,10 @@ def main(): annealing_func=args.lr_annealing_func, warmup_exp=args.lr_warmup_exp, annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, + cycles=lr_cycles, end_lr=1e3, train_epochs=num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, + warmup_epochs=lr_warmup_epochs, mid_point=args.lr_mid_point, ) @@ -966,7 +979,7 @@ def main(): lr_scheduler=lr_scheduler, num_train_epochs=num_train_epochs, global_step_offset=training_iter * num_train_steps, - initial_samples=training_iter == 0, + cycle=training_iter, # -- group_labels=["emb"], checkpoint_output_dir=checkpoint_output_dir, diff --git a/training/functional.py b/training/functional.py index 2da0f69..ebc40de 100644 --- a/training/functional.py +++ b/training/functional.py @@ -42,7 +42,7 @@ class TrainingCallbacks(): on_after_optimize: Callable[[Any, dict[str, float]], None] = const() on_after_epoch: Callable[[], None] = const() on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) - on_sample: Callable[[int], None] = const() + on_sample: Callable[[int, int], None] = const() on_checkpoint: Callable[[int, str], None] = const() @@ -96,6 +96,7 @@ def save_samples( output_dir: Path, seed: int, step: int, + cycle: int = 1, batch_size: int = 1, num_batches: int = 1, num_steps: int = 20, @@ -125,7 +126,7 @@ def save_samples( for pool, data, gen in datasets: all_samples = [] - file_path = output_dir / pool / f"step_{step}.jpg" + file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" file_path.parent.mkdir(parents=True, exist_ok=True) batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) @@ -455,7 +456,7 @@ def train_loop( sample_frequency: int = 10, checkpoint_frequency: int = 50, milestone_checkpoints: bool = True, - initial_samples: bool = True, + cycle: int = 1, global_step_offset: int = 0, num_epochs: int = 100, gradient_accumulation_steps: int = 1, @@ -518,12 +519,12 @@ def train_loop( try: for epoch in range(num_epochs): if accelerator.is_main_process: - if epoch % sample_frequency == 0 and (initial_samples or epoch != 0): + if epoch % sample_frequency == 0 and (cycle == 1 or epoch != 0): local_progress_bar.clear() global_progress_bar.clear() with on_eval(): - on_sample(global_step) + on_sample(cycle, global_step) if epoch % checkpoint_frequency == 0 and epoch != 0: local_progress_bar.clear() @@ -648,7 +649,7 @@ def train_loop( if accelerator.is_main_process: print("Finished!") with on_eval(): - on_sample(global_step) + on_sample(cycle, global_step) on_checkpoint(global_step, "end") except KeyboardInterrupt: @@ -680,7 +681,7 @@ def train( sample_frequency: int = 20, checkpoint_frequency: int = 50, milestone_checkpoints: bool = True, - initial_samples: bool = True, + cycle: int = 1, global_step_offset: int = 0, guidance_scale: float = 0.0, prior_loss_weight: float = 1.0, @@ -731,7 +732,7 @@ def train( sample_frequency=sample_frequency, checkpoint_frequency=checkpoint_frequency, milestone_checkpoints=milestone_checkpoints, - initial_samples=initial_samples, + cycle=cycle, global_step_offset=global_step_offset, num_epochs=num_train_epochs, gradient_accumulation_steps=gradient_accumulation_steps, diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 4ae28b7..e6fcc89 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -148,7 +148,7 @@ def dreambooth_strategy_callbacks( torch.cuda.empty_cache() @torch.no_grad() - def on_sample(step): + def on_sample(cycle, step): unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) @@ -158,7 +158,7 @@ def dreambooth_strategy_callbacks( unet_.to(dtype=weight_dtype) text_encoder_.to(dtype=weight_dtype) - save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) unet_.to(dtype=orig_unet_dtype) text_encoder_.to(dtype=orig_text_encoder_dtype) diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 48236fb..5c3012e 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -146,11 +146,11 @@ def lora_strategy_callbacks( torch.cuda.empty_cache() @torch.no_grad() - def on_sample(step): + def on_sample(cycle, step): unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) - save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) del unet_, text_encoder_ diff --git a/training/strategy/ti.py b/training/strategy/ti.py index f0b84b5..6bbff64 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -103,11 +103,29 @@ def textual_inversion_strategy_callbacks( with ema_context(): yield + @torch.no_grad() + def on_before_optimize(epoch: int): + if use_emb_decay: + params = [ + p + for p in text_encoder.text_model.embeddings.token_embedding.parameters() + if p.grad is not None + ] + return torch.stack(params) if len(params) != 0 else None + @torch.no_grad() def on_after_optimize(w, lrs: dict[str, float]): if ema_embeddings is not None: ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) + if use_emb_decay and w is not None: + lr = lrs["emb"] or lrs["0"] + lambda_ = emb_decay * lr + + if lambda_ != 0: + norm = w[:, :].norm(dim=-1, keepdim=True) + w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + def on_log(): if ema_embeddings is not None: return {"ema_decay": ema_embeddings.decay} @@ -125,7 +143,7 @@ def textual_inversion_strategy_callbacks( ) @torch.no_grad() - def on_sample(step): + def on_sample(cycle, step): unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) @@ -135,7 +153,7 @@ def textual_inversion_strategy_callbacks( unet_.to(dtype=weight_dtype) text_encoder_.to(dtype=weight_dtype) - save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) unet_.to(dtype=orig_unet_dtype) text_encoder_.to(dtype=orig_text_encoder_dtype) @@ -148,6 +166,7 @@ def textual_inversion_strategy_callbacks( return TrainingCallbacks( on_train=on_train, on_eval=on_eval, + on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, on_log=on_log, on_checkpoint=on_checkpoint, -- cgit v1.2.3-70-g09d2