From 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 8 Apr 2023 07:58:14 +0200
Subject: Update

---
 training/functional.py          | 57 ++++++++++++++++-------------------------
 training/strategy/dreambooth.py |  2 +-
 training/strategy/lora.py       | 12 ++++++---
 training/strategy/ti.py         |  2 +-
 4 files changed, 33 insertions(+), 40 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 4d83df1..71b2fe9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -36,8 +36,8 @@ def const(result=None):
 class TrainingCallbacks():
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[float, int], Any] = const()
-    on_after_optimize: Callable[[Any, float], None] = const()
+    on_before_optimize: Callable[[int], Any] = const()
+    on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
     on_after_epoch: Callable[[], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
@@ -422,6 +422,7 @@ def train_loop(
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
+    group_labels: list[str] = [],
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -442,10 +443,6 @@ def train_loop(
     best_acc = 0.0
     best_acc_val = 0.0
 
-    lrs = []
-    losses = []
-    accs = []
-
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
@@ -496,6 +493,8 @@ def train_loop(
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
+            logs = {}
+
             with on_train(epoch):
                 for step, batch in enumerate(train_dataloader):
                     loss, acc, bsz = loss_step(step, batch, cache)
@@ -506,31 +505,36 @@ def train_loop(
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
 
-                    lr = lr_scheduler.get_last_lr()[0]
-                    if torch.is_tensor(lr):
-                        lr = lr.item()
-
                     logs = {
                         "train/loss": avg_loss.avg.item(),
                         "train/acc": avg_acc.avg.item(),
                         "train/cur_loss": loss.item(),
                         "train/cur_acc": acc.item(),
-                        "lr": lr,
                     }
-                    if isDadaptation:
-                        logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+
+                    lrs: dict[str, float] = {}
+                    for i, lr in enumerate(lr_scheduler.get_last_lr()):
+                        if torch.is_tensor(lr):
+                            lr = lr.item()
+                        label = group_labels[i] if i < len(group_labels) else f"{i}"
+                        logs[f"lr/{label}"] = lr
+                        if isDadaptation:
+                            lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                            logs[f"d*lr/{label}"] = lr
+                        lrs[label] = lr
+
                     logs.update(on_log())
 
                     local_progress_bar.set_postfix(**logs)
 
                     if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                        before_optimize_result = on_before_optimize(lr, epoch)
+                        before_optimize_result = on_before_optimize(epoch)
 
                         optimizer.step()
                         lr_scheduler.step()
                         optimizer.zero_grad(set_to_none=True)
 
-                        on_after_optimize(before_optimize_result, lr)
+                        on_after_optimize(before_optimize_result, lrs)
 
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
@@ -544,15 +548,6 @@ def train_loop(
 
             accelerator.wait_for_everyone()
 
-            if isDadaptation:
-                lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
-            else:
-                lr = lr_scheduler.get_last_lr()[0]
-                if torch.is_tensor(lr):
-                    lr = lr.item()
-
-            lrs.append(lr)
-
             on_after_epoch()
 
             if val_dataloader is not None:
@@ -597,9 +592,6 @@ def train_loop(
                             f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                         on_checkpoint(global_step + global_step_offset, "milestone")
                         best_acc_val = avg_acc_val.avg.item()
-
-                losses.append(avg_loss_val.avg.item())
-                accs.append(avg_acc_val.avg.item())
             else:
                 if accelerator.is_main_process:
                     if avg_acc.avg.item() > best_acc and milestone_checkpoints:
@@ -611,9 +603,6 @@ def train_loop(
                         on_checkpoint(global_step + global_step_offset, "milestone")
                         best_acc = avg_acc.avg.item()
 
-                losses.append(avg_loss.avg.item())
-                accs.append(avg_acc.avg.item())
-
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
@@ -626,8 +615,6 @@ def train_loop(
             on_checkpoint(global_step + global_step_offset, "end")
         raise KeyboardInterrupt
 
-    return lrs, losses, accs
-
 
 def train(
     accelerator: Accelerator,
@@ -646,6 +633,7 @@ def train(
     no_val: bool = False,
     num_train_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
+    group_labels: list[str] = [],
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
@@ -692,7 +680,7 @@ def train(
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
-    metrics = train_loop(
+    train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
@@ -705,10 +693,9 @@ def train(
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
+        group_labels=group_labels,
         callbacks=callbacks,
     )
 
     accelerator.end_training()
     accelerator.free_memory()
-
-    return metrics
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0286673..695174a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         params_to_clip = [unet.parameters()]
         if epoch < train_text_encoder_epochs:
             params_to_clip.append(text_encoder.parameters())
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
-                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                itertools.chain(
+                    unet.parameters(),
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
                 max_grad_norm
             )
 
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
         return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(w, lrs: dict[str, float]):
+        lr = lrs["emb"] or lrs["0"]
+
         if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a637c3..d735dac 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
             yield
 
     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if use_emb_decay:
             params = [
                 p
--
cgit v1.2.3-70-g09d2
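
For context (not part of the patch): the core of the functional.py change is that learning rates are now tracked per optimizer parameter group, keyed by group_labels, and the resulting label-to-lr mapping is handed to on_after_optimize instead of a single scalar lr. Below is a minimal standalone sketch of that bookkeeping; the helper name collect_lr_logs and the example values are hypothetical, not part of the repository.

# Standalone sketch of the per-parameter-group LR bookkeeping introduced above.
# collect_lr_logs and the example values below are made up for illustration.
def collect_lr_logs(last_lrs, param_groups, group_labels, is_dadaptation=False):
    logs = {}  # metric name -> value, for the progress bar / tracker
    lrs = {}   # group label -> effective lr, passed to on_after_optimize
    for i, lr in enumerate(last_lrs):
        # Fall back to the group's index when no label was provided.
        label = group_labels[i] if i < len(group_labels) else f"{i}"
        logs[f"lr/{label}"] = lr
        if is_dadaptation:
            # D-Adaptation exposes its adapted step size as d * lr.
            lr = param_groups[i]["d"] * param_groups[i]["lr"]
            logs[f"d*lr/{label}"] = lr
        lrs[label] = lr
    return logs, lrs


logs, lrs = collect_lr_logs(
    last_lrs=[1e-4, 5e-5],
    param_groups=[{"d": 1.0, "lr": 1e-4}, {"d": 1.0, "lr": 5e-5}],
    group_labels=["emb", "unet"],
)
# logs == {"lr/emb": 0.0001, "lr/unet": 5e-05}
# lrs  == {"emb": 0.0001, "unet": 5e-05}

A visible consequence in lora.py is that on_after_optimize callbacks now pick the rate they care about by label (e.g. lrs["emb"]) rather than receiving whichever parameter group happened to come first.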