From 4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 25 Jun 2023 08:40:05 +0200
Subject: Update

---
 train_dreambooth.py             | 19 ++++++++++++++++++-
 train_ti.py                     | 42 ++++++++++++++++++++++++++++++++++++------
 training/functional.py          |  3 ---
 training/strategy/dreambooth.py |  3 +++
 training/strategy/ti.py         |  7 +++++++
 5 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 929310b..90ca467 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -839,7 +839,10 @@ def main():
 
         create_optimizer = partial(
             prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
             weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
         )
 
         args.learning_rate_unet = 1.0
@@ -965,9 +968,23 @@ def main():
         },
         {
             "params": (
-                param for param in text_encoder.parameters() if param.requires_grad
+                param
+                for param in itertools.chain(
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                )
+                if param.requires_grad
+            ),
+            "lr": learning_rate_text,
+        },
+        {
+            "params": (
+                param
+                for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+                if param.requires_grad
             ),
             "lr": learning_rate_text,
+            "weight_decay": 0,
         },
     ]
     group_labels = ["unet", "text"]
diff --git a/train_ti.py b/train_ti.py
index 1d0cb6f..a7d2924 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -337,7 +337,16 @@ def parse_args():
         "--optimizer",
         type=str,
         default="adan",
-        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
+        choices=[
+            "adam",
+            "adam8bit",
+            "adan",
+            "lion",
+            "dadam",
+            "dadan",
+            "adafactor",
+            "prodigy",
+        ],
         help="Optimizer to use",
     )
     parser.add_argument(
@@ -819,6 +828,23 @@ def main():
             eps=args.adam_epsilon,
             d0=args.dadaptation_d0,
         )
+    elif args.optimizer == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`."
+            )
+
+        create_optimizer = partial(
+            prodigyopt.Prodigy,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            d0=args.dadaptation_d0,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
@@ -959,7 +985,11 @@ def main():
     avg_acc_val = AverageMeter()
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        (
+            param
+            for param in text_encoder.text_model.embeddings.token_embedding.parameters()
+            if param.requires_grad
+        ),
         lr=args.learning_rate,
     )
 
@@ -973,9 +1003,11 @@ def main():
 
         if response.lower().strip() == "o":
             if args.learning_rate is not None:
-                learning_rate = args.learning_rate * 2
+                learning_rate = (
+                    args.learning_rate * 2 * (args.cycle_decay**training_iter)
+                )
             else:
-                learning_rate = args.learning_rate
+                learning_rate = args.learning_rate * (args.cycle_decay**training_iter)
 
         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
@@ -1045,8 +1077,6 @@ def main():
         )
 
         training_iter += 1
-        if learning_rate is not None:
-            learning_rate *= args.cycle_decay
 
     accelerator.end_training()
 
diff --git a/training/functional.py b/training/functional.py
index 8917eb7..b60afe3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -786,7 +786,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
-    text_encoder.forward = MethodType(text_encoder.forward, text_encoder)
-    unet.forward = MethodType(unet.forward, unet)
-
     accelerator.free_memory()
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 3d1abf7..7e67589 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+        unet_.forward = MethodType(unet_.forward, unet_)
+
         text_encoder_.text_model.embeddings.persist(False)
 
         with ema_context():
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 7373982..f37dfb4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from types import MethodType
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
+        if postfix == "end":
+            text_encoder_ = accelerator.unwrap_model(
+                text_encoder, keep_fp32_wrapper=False
+            )
+            text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+
         with ema_context():
             for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
-- 
cgit v1.2.3-70-g09d2
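
For context (not part of the patch): a minimal sketch of the Prodigy setup this commit introduces. Both scripts add a "prodigy" optimizer choice built via functools.partial with Adam-style hyperparameters plus d0, and pin the nominal learning rate to 1.0, since Prodigy estimates the effective step size on its own; train_dreambooth.py additionally splits the text-encoder parameters into a weight-decayed group and a decay-free token-embedding group. In the sketch below, the small nn modules and the literal hyperparameter values are placeholders standing in for the CLIP text-encoder submodules and the scripts' args.* values.

import itertools
from functools import partial

import torch.nn as nn
import prodigyopt  # pip install prodigyopt

# Placeholder modules standing in for the CLIP text-encoder parts the scripts use.
encoder = nn.Linear(16, 16)            # ~ text_encoder.text_model.encoder
final_layer_norm = nn.LayerNorm(16)    # ~ text_encoder.text_model.final_layer_norm
token_embedding = nn.Embedding(8, 16)  # ~ text_encoder.text_model.embeddings.token_embedding

# As in the patch: Prodigy receives Adam-style hyperparameters plus d0, and the
# nominal learning rate stays at 1.0 because the optimizer adapts the step size.
create_optimizer = partial(
    prodigyopt.Prodigy,
    betas=(0.9, 0.999),  # args.adam_beta1, args.adam_beta2
    weight_decay=1e-2,   # args.adam_weight_decay
    eps=1e-8,            # args.adam_epsilon
    d0=1e-6,             # args.dadaptation_d0 (initial step-size estimate)
)

optimizer = create_optimizer(
    [
        {
            # Encoder blocks and the final layer norm keep the default weight decay.
            "params": (
                p
                for p in itertools.chain(
                    encoder.parameters(), final_layer_norm.parameters()
                )
                if p.requires_grad
            ),
            "lr": 1.0,
        },
        {
            # Token embeddings are excluded from weight decay, as in the commit,
            # presumably to keep decay from pulling the learned vectors toward zero.
            "params": (p for p in token_embedding.parameters() if p.requires_grad),
            "lr": 1.0,
            "weight_decay": 0,
        },
    ]
)

The generator expressions are materialized by torch.optim.Optimizer when the parameter groups are added, so passing them directly, as the patch does, works with any optimizer that subclasses it.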