author | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-01 12:35:43 +0200 |
commit | 01eee0cb24f52ca78761b78917959e1c247eae94 (patch) | |
tree | 914c0d3f5b888a4c344b30a861639c8e3d5259dd /train_lora.py | |
parent | Update (diff) | |
Add support for Adafactor, add TI initializer noise
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 16 |
1 file changed, 15 insertions(+), 1 deletion(-)
```diff
diff --git a/train_lora.py b/train_lora.py
index f74a438..f8dccae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -14,6 +14,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -317,7 +318,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -567,6 +568,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
 elif args.optimizer == 'dadam':
     try:
         import dadaptation
```