summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py2
-rw-r--r--train_ti.py10
2 files changed, 9 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index da3a075..0fe590f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -907,7 +907,7 @@ def main():
907 on_before_optimize=on_before_optimize, 907 on_before_optimize=on_before_optimize,
908 on_after_optimize=on_after_optimize, 908 on_after_optimize=on_after_optimize,
909 ) 909 )
910 lr_finder.run(num_epochs=100, end_lr=1e3) 910 lr_finder.run(num_epochs=100, end_lr=1e2)
911 911
912 plt.savefig(basepath.joinpath("lr.png"), dpi=300) 912 plt.savefig(basepath.joinpath("lr.png"), dpi=300)
913 plt.close() 913 plt.close()
diff --git a/train_ti.py b/train_ti.py
index 3b7e3b1..e18ee38 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -415,11 +415,17 @@ def parse_args():
415 ) 415 )
416 parser.add_argument( 416 parser.add_argument(
417 "--decay_factor", 417 "--decay_factor",
418 default=100, 418 default=1,
419 type=float, 419 type=float,
420 help="Embedding decay factor." 420 help="Embedding decay factor."
421 ) 421 )
422 parser.add_argument( 422 parser.add_argument(
423 "--decay_start",
424 default=1e-4,
425 type=float,
426 help="Embedding decay start offset."
427 )
428 parser.add_argument(
423 "--noise_timesteps", 429 "--noise_timesteps",
424 type=int, 430 type=int,
425 default=1000, 431 default=1000,
@@ -833,7 +839,7 @@ def main():
833 def on_after_optimize(lr: float): 839 def on_after_optimize(lr: float):
834 text_encoder.text_model.embeddings.normalize( 840 text_encoder.text_model.embeddings.normalize(
835 args.decay_target, 841 args.decay_target,
836 min(1.0, args.decay_factor * lr) 842 min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
837 ) 843 )
838 844
839 loop = partial( 845 loop = partial(