diff options
-rw-r--r-- | train_dreambooth.py | 2 | ||||
-rw-r--r-- | train_ti.py | 10 |
2 files changed, 9 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index da3a075..0fe590f 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -907,7 +907,7 @@ def main(): | |||
907 | on_before_optimize=on_before_optimize, | 907 | on_before_optimize=on_before_optimize, |
908 | on_after_optimize=on_after_optimize, | 908 | on_after_optimize=on_after_optimize, |
909 | ) | 909 | ) |
910 | lr_finder.run(num_epochs=100, end_lr=1e3) | 910 | lr_finder.run(num_epochs=100, end_lr=1e2) |
911 | 911 | ||
912 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 912 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
913 | plt.close() | 913 | plt.close() |
diff --git a/train_ti.py b/train_ti.py index 3b7e3b1..e18ee38 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -415,11 +415,17 @@ def parse_args(): | |||
415 | ) | 415 | ) |
416 | parser.add_argument( | 416 | parser.add_argument( |
417 | "--decay_factor", | 417 | "--decay_factor", |
418 | default=100, | 418 | default=1, |
419 | type=float, | 419 | type=float, |
420 | help="Embedding decay factor." | 420 | help="Embedding decay factor." |
421 | ) | 421 | ) |
422 | parser.add_argument( | 422 | parser.add_argument( |
423 | "--decay_start", | ||
424 | default=1e-4, | ||
425 | type=float, | ||
426 | help="Embedding decay start offset." | ||
427 | ) | ||
428 | parser.add_argument( | ||
423 | "--noise_timesteps", | 429 | "--noise_timesteps", |
424 | type=int, | 430 | type=int, |
425 | default=1000, | 431 | default=1000, |
@@ -833,7 +839,7 @@ def main(): | |||
833 | def on_after_optimize(lr: float): | 839 | def on_after_optimize(lr: float): |
834 | text_encoder.text_model.embeddings.normalize( | 840 | text_encoder.text_model.embeddings.normalize( |
835 | args.decay_target, | 841 | args.decay_target, |
836 | min(1.0, args.decay_factor * lr) | 842 | min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) |
837 | ) | 843 | ) |
838 | 844 | ||
839 | loop = partial( | 845 | loop = partial( |