diff options
author | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
commit | 01eee0cb24f52ca78761b78917959e1c247eae94 (patch) | |
tree | 914c0d3f5b888a4c344b30a861639c8e3d5259dd /train_dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.gz textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.tar.bz2 textual-inversion-diff-01eee0cb24f52ca78761b78917959e1c247eae94.zip |
Add support for Adafactor, add TI initializer noise
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 16 |
1 file changed, 15 insertions(+), 1 deletion(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py index 3a25efa..4456bd1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -13,6 +13,7 @@ from accelerate import Accelerator | |||
13 | from accelerate.logging import get_logger | 13 | from accelerate.logging import get_logger |
14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
15 | from slugify import slugify | 15 | from slugify import slugify |
16 | import transformers | ||
16 | 17 | ||
17 | from util.files import load_config, load_embeddings_from_dir | 18 | from util.files import load_config, load_embeddings_from_dir |
18 | from data.csv import VlpnDataModule, keyword_filter | 19 | from data.csv import VlpnDataModule, keyword_filter |
@@ -305,7 +306,7 @@ def parse_args(): | |||
305 | "--optimizer", | 306 | "--optimizer", |
306 | type=str, | 307 | type=str, |
307 | default="dadan", | 308 | default="dadan", |
308 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 309 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' |
309 | ) | 310 | ) |
310 | parser.add_argument( | 311 | parser.add_argument( |
311 | "--dadaptation_d0", | 312 | "--dadaptation_d0", |
@@ -535,6 +536,19 @@ def main(): | |||
535 | eps=args.adam_epsilon, | 536 | eps=args.adam_epsilon, |
536 | amsgrad=args.adam_amsgrad, | 537 | amsgrad=args.adam_amsgrad, |
537 | ) | 538 | ) |
539 | elif args.optimizer == 'adafactor': | ||
540 | create_optimizer = partial( | ||
541 | transformers.optimization.Adafactor, | ||
542 | beta1=args.adam_beta1, | ||
543 | weight_decay=args.adam_weight_decay, | ||
544 | scale_parameter=True, | ||
545 | relative_step=True, | ||
546 | warmup_init=True, | ||
547 | ) | ||
548 | |||
549 | args.lr_scheduler = "adafactor" | ||
550 | args.lr_min_lr = args.learning_rate | ||
551 | args.learning_rate = None | ||
538 | elif args.optimizer == 'dadam': | 552 | elif args.optimizer == 'dadam': |
539 | try: | 553 | try: |
540 | import dadaptation | 554 | import dadaptation |