summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-13 09:17:44 +0100
committerVolpeon <git@volpeon.ink>2023-01-13 09:17:44 +0100
commit07e8111a98afbf912ba99a28e2d7ea647b2d29fc (patch)
tree853aaae6f5ff477b73d5649f3f89303ff8654b70 /train_ti.py
parentCode deduplication (diff)
downloadtextual-inversion-diff-07e8111a98afbf912ba99a28e2d7ea647b2d29fc.tar.gz
textual-inversion-diff-07e8111a98afbf912ba99a28e2d7ea647b2d29fc.tar.bz2
textual-inversion-diff-07e8111a98afbf912ba99a28e2d7ea647b2d29fc.zip
Added TI decay start offset
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py
index 3b7e3b1..e18ee38 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -415,11 +415,17 @@ def parse_args():
415 ) 415 )
416 parser.add_argument( 416 parser.add_argument(
417 "--decay_factor", 417 "--decay_factor",
418 default=100, 418 default=1,
419 type=float, 419 type=float,
420 help="Embedding decay factor." 420 help="Embedding decay factor."
421 ) 421 )
422 parser.add_argument( 422 parser.add_argument(
423 "--decay_start",
424 default=1e-4,
425 type=float,
426 help="Embedding decay start offset."
427 )
428 parser.add_argument(
423 "--noise_timesteps", 429 "--noise_timesteps",
424 type=int, 430 type=int,
425 default=1000, 431 default=1000,
@@ -833,7 +839,7 @@ def main():
833 def on_after_optimize(lr: float): 839 def on_after_optimize(lr: float):
834 text_encoder.text_model.embeddings.normalize( 840 text_encoder.text_model.embeddings.normalize(
835 args.decay_target, 841 args.decay_target,
836 min(1.0, args.decay_factor * lr) 842 min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
837 ) 843 )
838 844
839 loop = partial( 845 loop = partial(