diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 16 |
1 file changed, 15 insertions, 1 deletion
diff --git a/train_dreambooth.py b/train_dreambooth.py index 3a25efa..4456bd1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -13,6 +13,7 @@ from accelerate import Accelerator | |||
| 13 | from accelerate.logging import get_logger | 13 | from accelerate.logging import get_logger |
| 14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
| 15 | from slugify import slugify | 15 | from slugify import slugify |
| 16 | import transformers | ||
| 16 | 17 | ||
| 17 | from util.files import load_config, load_embeddings_from_dir | 18 | from util.files import load_config, load_embeddings_from_dir |
| 18 | from data.csv import VlpnDataModule, keyword_filter | 19 | from data.csv import VlpnDataModule, keyword_filter |
| @@ -305,7 +306,7 @@ def parse_args(): | |||
| 305 | "--optimizer", | 306 | "--optimizer", |
| 306 | type=str, | 307 | type=str, |
| 307 | default="dadan", | 308 | default="dadan", |
| 308 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]' | 309 | help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]' |
| 309 | ) | 310 | ) |
| 310 | parser.add_argument( | 311 | parser.add_argument( |
| 311 | "--dadaptation_d0", | 312 | "--dadaptation_d0", |
| @@ -535,6 +536,19 @@ def main(): | |||
| 535 | eps=args.adam_epsilon, | 536 | eps=args.adam_epsilon, |
| 536 | amsgrad=args.adam_amsgrad, | 537 | amsgrad=args.adam_amsgrad, |
| 537 | ) | 538 | ) |
| 539 | elif args.optimizer == 'adafactor': | ||
| 540 | create_optimizer = partial( | ||
| 541 | transformers.optimization.Adafactor, | ||
| 542 | beta1=args.adam_beta1, | ||
| 543 | weight_decay=args.adam_weight_decay, | ||
| 544 | scale_parameter=True, | ||
| 545 | relative_step=True, | ||
| 546 | warmup_init=True, | ||
| 547 | ) | ||
| 548 | |||
| 549 | args.lr_scheduler = "adafactor" | ||
| 550 | args.lr_min_lr = args.learning_rate | ||
| 551 | args.learning_rate = None | ||
| 538 | elif args.optimizer == 'dadam': | 552 | elif args.optimizer == 'dadam': |
| 539 | try: | 553 | try: |
| 540 | import dadaptation | 554 | import dadaptation |
