diff options
| -rw-r--r-- | train_dreambooth.py | 2 | ||||
| -rw-r--r-- | train_ti.py | 10 |
2 files changed, 9 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index da3a075..0fe590f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
| @@ -907,7 +907,7 @@ def main(): | |||
| 907 | on_before_optimize=on_before_optimize, | 907 | on_before_optimize=on_before_optimize, |
| 908 | on_after_optimize=on_after_optimize, | 908 | on_after_optimize=on_after_optimize, |
| 909 | ) | 909 | ) |
| 910 | lr_finder.run(num_epochs=100, end_lr=1e3) | 910 | lr_finder.run(num_epochs=100, end_lr=1e2) |
| 911 | 911 | ||
| 912 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 912 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
| 913 | plt.close() | 913 | plt.close() |
diff --git a/train_ti.py b/train_ti.py
index 3b7e3b1..e18ee38 100644
--- a/train_ti.py
+++ b/train_ti.py
| @@ -415,11 +415,17 @@ def parse_args(): | |||
| 415 | ) | 415 | ) |
| 416 | parser.add_argument( | 416 | parser.add_argument( |
| 417 | "--decay_factor", | 417 | "--decay_factor", |
| 418 | default=100, | 418 | default=1, |
| 419 | type=float, | 419 | type=float, |
| 420 | help="Embedding decay factor." | 420 | help="Embedding decay factor." |
| 421 | ) | 421 | ) |
| 422 | parser.add_argument( | 422 | parser.add_argument( |
| 423 | "--decay_start", | ||
| 424 | default=1e-4, | ||
| 425 | type=float, | ||
| 426 | help="Embedding decay start offset." | ||
| 427 | ) | ||
| 428 | parser.add_argument( | ||
| 423 | "--noise_timesteps", | 429 | "--noise_timesteps", |
| 424 | type=int, | 430 | type=int, |
| 425 | default=1000, | 431 | default=1000, |
| @@ -833,7 +839,7 @@ def main(): | |||
| 833 | def on_after_optimize(lr: float): | 839 | def on_after_optimize(lr: float): |
| 834 | text_encoder.text_model.embeddings.normalize( | 840 | text_encoder.text_model.embeddings.normalize( |
| 835 | args.decay_target, | 841 | args.decay_target, |
| 836 | min(1.0, args.decay_factor * lr) | 842 | min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) |
| 837 | ) | 843 | ) |
| 838 | 844 | ||
| 839 | loop = partial( | 845 | loop = partial( |
