From 07e8111a98afbf912ba99a28e2d7ea647b2d29fc Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 09:17:44 +0100 Subject: Added TI decay start offset --- train_dreambooth.py | 2 +- train_ti.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index da3a075..0fe590f 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -907,7 +907,7 @@ def main(): on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, ) - lr_finder.run(num_epochs=100, end_lr=1e3) + lr_finder.run(num_epochs=100, end_lr=1e2) plt.savefig(basepath.joinpath("lr.png"), dpi=300) plt.close() diff --git a/train_ti.py b/train_ti.py index 3b7e3b1..e18ee38 100644 --- a/train_ti.py +++ b/train_ti.py @@ -415,10 +415,16 @@ def parse_args(): ) parser.add_argument( "--decay_factor", - default=100, + default=1, type=float, help="Embedding decay factor." ) + parser.add_argument( + "--decay_start", + default=1e-4, + type=float, + help="Embedding decay start offset." + ) parser.add_argument( "--noise_timesteps", type=int, @@ -833,7 +839,7 @@ def main(): def on_after_optimize(lr: float): text_encoder.text_model.embeddings.normalize( args.decay_target, - min(1.0, args.decay_factor * lr) + min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) ) loop = partial( -- cgit v1.2.3-54-g00ecf