From 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 8 Apr 2023 07:58:14 +0200
Subject: Update

---
 train_dreambooth.py             |  5 +--
 train_lora.py                   | 75 ++++++++++++++++++++++-------------------
 train_ti.py                     | 12 +++----
 training/functional.py          | 57 ++++++++++++-------------------
 training/strategy/dreambooth.py |  2 +-
 training/strategy/lora.py       | 12 +++++--
 training/strategy/ti.py         |  2 +-
 7 files changed, 80 insertions(+), 85 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48921d4..f4d4cbb 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -18,7 +18,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
-from training.lr import plot_metrics
 from training.strategy.dreambooth import dreambooth_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -692,7 +691,7 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    metrics = trainer(
+    trainer(
         strategy=dreambooth_strategy,
         project="dreambooth",
         train_dataloader=datamodule.train_dataloader,
@@ -721,8 +720,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    plot_metrics(metrics, output_dir / "lr.png")
-
 
 if __name__ == "__main__":
     main()
diff --git a/train_lora.py b/train_lora.py
index 9f17495..1626be6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -19,7 +19,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
-from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -568,6 +567,9 @@ def parse_args():
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
+    if len(args.initializer_tokens) == 0:
+        args.initializer_tokens = args.placeholder_tokens.copy()
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
@@ -918,7 +920,7 @@ def main():
         train_epochs=num_pti_epochs,
     )
 
-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         pti_mode=True,
         project="pti",
@@ -929,12 +931,13 @@ def main():
         num_train_epochs=num_pti_epochs,
         gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         # --
+        group_labels=["emb"],
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
     )
 
-    plot_metrics(metrics, pti_output_dir / "lr.png")
+    # embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -957,34 +960,39 @@ def main():
     ) * args.gradient_accumulation_steps
     lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
 
-    lora_optimizer = create_optimizer(
-        [
-            {
-                "params": (
-                    param
-                    for param in unet.parameters()
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_unet,
-            },
-            {
-                "params": (
-                    param
-                    for param in itertools.chain(
-                        text_encoder.text_model.encoder.parameters(),
-                        text_encoder.text_model.final_layer_norm.parameters(),
-                    )
-                    if param.requires_grad
-                ),
-                "lr": args.learning_rate_text,
-            },
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_text,
-                "weight_decay": 0,
-            },
-        ]
-    )
+    params_to_optimize = []
+    group_labels = []
+    if len(args.placeholder_tokens) != 0:
+        params_to_optimize.append({
+            "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            "lr": args.learning_rate_text,
+            "weight_decay": 0,
+        })
+        group_labels.append("emb")
+    params_to_optimize += [
+        {
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_unet,
+        },
+        {
+            "params": (
+                param
+                for param in itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                )
+                if param.requires_grad
+            ),
+            "lr": args.learning_rate_text,
+        },
+    ]
+    group_labels += ["unet", "text"]
+
+    lora_optimizer = create_optimizer(params_to_optimize)
 
     lora_lr_scheduler = create_lr_scheduler(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -993,7 +1001,7 @@ def main():
         train_epochs=num_train_epochs,
     )
 
-    metrics = trainer(
+    trainer(
         strategy=lora_strategy,
         project="lora",
         train_dataloader=lora_datamodule.train_dataloader,
@@ -1003,13 +1011,12 @@ def main():
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         # --
+        group_labels=group_labels,
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
     )
 
-    plot_metrics(metrics, lora_output_dir / "lr.png")
-
 
 if __name__ == "__main__":
     main()
diff --git a/train_ti.py b/train_ti.py
index c1c0eed..48858cc 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,6 @@ import transformers
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
-from training.lr import plot_metrics
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -511,12 +510,12 @@ def parse_args():
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if len(args.initializer_tokens) == 0:
-        raise ValueError("You must specify --initializer_tokens")
-
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
+    if len(args.initializer_tokens) == 0:
+        args.initializer_tokens = args.placeholder_tokens.copy()
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
@@ -856,7 +855,7 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    metrics = trainer(
+    trainer(
         project="textual_inversion",
         train_dataloader=datamodule.train_dataloader,
         val_dataloader=datamodule.val_dataloader,
@@ -864,14 +863,13 @@ def main():
         lr_scheduler=lr_scheduler,
         num_train_epochs=num_train_epochs,
         # --
+        group_labels=["emb"],
         sample_output_dir=sample_output_dir,
         sample_frequency=sample_frequency,
         placeholder_tokens=placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
     )
 
-    plot_metrics(metrics, metrics_output_file)
-
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
     else:
diff --git a/training/functional.py b/training/functional.py
index 4d83df1..71b2fe9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -36,8 +36,8 @@ def const(result=None):
 class TrainingCallbacks():
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[float, int], Any] = const()
-    on_after_optimize: Callable[[Any, float], None] = const()
+    on_before_optimize: Callable[[int], Any] = const()
+    on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
     on_after_epoch: Callable[[], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
@@ -422,6 +422,7 @@ def train_loop(
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
+    group_labels: list[str] = [],
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -442,10 +443,6 @@ def train_loop(
     best_acc = 0.0
     best_acc_val = 0.0
 
-    lrs = []
-    losses = []
-    accs = []
-
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
@@ -496,6 +493,8 @@ def train_loop(
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
+        logs = {}
+
         with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
@@ -506,31 +505,36 @@ def train_loop(
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
-                lr = lr_scheduler.get_last_lr()[0]
-                if torch.is_tensor(lr):
-                    lr = lr.item()
-
                 logs = {
                     "train/loss": avg_loss.avg.item(),
                     "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
-                    "lr": lr,
                 }
-                if isDadaptation:
-                    logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+
+                lrs: dict[str, float] = {}
+                for i, lr in enumerate(lr_scheduler.get_last_lr()):
+                    if torch.is_tensor(lr):
+                        lr = lr.item()
+                    label = group_labels[i] if i < len(group_labels) else f"{i}"
+                    logs[f"lr/{label}"] = lr
+                    if isDadaptation:
+                        lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                        logs[f"d*lr/{label}"] = lr
+                    lrs[label] = lr
+
                 logs.update(on_log())
 
                 local_progress_bar.set_postfix(**logs)
 
                 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                    before_optimize_result = on_before_optimize(lr, epoch)
+                    before_optimize_result = on_before_optimize(epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
 
-                    on_after_optimize(before_optimize_result, lr)
+                    on_after_optimize(before_optimize_result, lrs)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -544,15 +548,6 @@ def train_loop(
 
         accelerator.wait_for_everyone()
 
-        if isDadaptation:
-            lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
-        else:
-            lr = lr_scheduler.get_last_lr()[0]
-            if torch.is_tensor(lr):
-                lr = lr.item()
-
-        lrs.append(lr)
-
         on_after_epoch()
 
         if val_dataloader is not None:
@@ -597,9 +592,6 @@ def train_loop(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                     on_checkpoint(global_step + global_step_offset, "milestone")
                     best_acc_val = avg_acc_val.avg.item()
-
-            losses.append(avg_loss_val.avg.item())
-            accs.append(avg_acc_val.avg.item())
         else:
             if accelerator.is_main_process:
                 if avg_acc.avg.item() > best_acc and milestone_checkpoints:
@@ -611,9 +603,6 @@ def train_loop(
                     on_checkpoint(global_step + global_step_offset, "milestone")
                     best_acc = avg_acc.avg.item()
 
-            losses.append(avg_loss.avg.item())
-            accs.append(avg_acc.avg.item())
-
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         print("Finished!")
@@ -626,8 +615,6 @@ def train_loop(
             on_checkpoint(global_step + global_step_offset, "end")
             raise KeyboardInterrupt
 
-    return lrs, losses, accs
-
 
 def train(
     accelerator: Accelerator,
@@ -646,6 +633,7 @@ def train(
     no_val: bool = False,
     num_train_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
+    group_labels: list[str] = [],
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
@@ -692,7 +680,7 @@ def train(
     if accelerator.is_main_process:
         accelerator.init_trackers(project)
 
-    metrics = train_loop(
+    train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
@@ -705,10 +693,9 @@ def train(
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
+        group_labels=group_labels,
         callbacks=callbacks,
     )
 
     accelerator.end_training()
     accelerator.free_memory()
-
-    return metrics
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0286673..695174a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         params_to_clip = [unet.parameters()]
         if epoch < train_text_encoder_epochs:
             params_to_clip.append(text_encoder.parameters())
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
-                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                itertools.chain(
+                    unet.parameters(),
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
                 max_grad_norm
             )
 
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
         return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(w, lrs: dict[str, float]):
+        lr = lrs["emb"] or lrs["0"]
+
         if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a637c3..d735dac 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if use_emb_decay:
             params = [
                 p
-- 
cgit v1.2.3-70-g09d2