From 07e8111a98afbf912ba99a28e2d7ea647b2d29fc Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 09:17:44 +0100 Subject: Added TI decay start offset --- train_ti.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 3b7e3b1..e18ee38 100644 --- a/train_ti.py +++ b/train_ti.py @@ -415,10 +415,16 @@ def parse_args(): ) parser.add_argument( "--decay_factor", - default=100, + default=1, type=float, help="Embedding decay factor." ) + parser.add_argument( + "--decay_start", + default=1e-4, + type=float, + help="Embedding decay start offset." + ) parser.add_argument( "--noise_timesteps", type=int, @@ -833,7 +839,7 @@ def main(): def on_after_optimize(lr: float): text_encoder.text_model.embeddings.normalize( args.decay_target, - min(1.0, args.decay_factor * lr) + min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) ) loop = partial( -- cgit v1.2.3-54-g00ecf