author    Volpeon <git@volpeon.ink>  2023-04-01 12:35:43 +0200
committer Volpeon <git@volpeon.ink>  2023-04-01 12:35:43 +0200
commit    01eee0cb24f52ca78761b78917959e1c247eae94
tree      914c0d3f5b888a4c344b30a861639c8e3d5259dd /train_ti.py
parent    Update
Add support for Adafactor, add TI initializer noise
Diffstat (limited to 'train_ti.py')
 train_ti.py | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index dd015f9..274a1ca 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,6 +12,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -75,6 +76,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--initializer_noise",
+        type=float,
+        default=0,
+        help="Noise to apply to the initializer word"
+    )
+    parser.add_argument(
         "--alias_tokens",
         type=str,
         nargs='*',
@@ -323,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan"]'
+        help='Optimizer to use ["adam", "adam8bit", "dadam", "dadan", "adafactor"]'
     )
     parser.add_argument(
         "--dadaptation_d0",
@@ -659,6 +666,19 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'adafactor':
+        create_optimizer = partial(
+            transformers.optimization.Adafactor,
+            beta1=args.adam_beta1,
+            weight_decay=args.adam_weight_decay,
+            scale_parameter=True,
+            relative_step=True,
+            warmup_init=True,
+        )
+
+        args.lr_scheduler = "adafactor"
+        args.lr_min_lr = args.learning_rate
+        args.learning_rate = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -739,7 +759,8 @@ def main():
         embeddings=embeddings,
         placeholder_tokens=placeholder_tokens,
         initializer_tokens=initializer_tokens,
-        num_vectors=num_vectors
+        num_vectors=num_vectors,
+        initializer_noise=args.initializer_noise,
     )
 
     stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
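
Note on the --initializer_noise flag: the second hunk adds the argument and the last hunk threads it into add_placeholder_tokens, but that function's body is outside this diff. The following is only a minimal sketch of what such a parameter could do, assuming Gaussian jitter on the copied initializer embedding; the helper name and formulation are assumptions, not the commit's actual implementation.

import torch

def apply_initializer_noise(embedding: torch.Tensor, noise: float) -> torch.Tensor:
    # Hypothetical sketch: perturb the initializer word's embedding so each
    # placeholder vector starts near, but not exactly at, the initializer.
    # `noise` corresponds to the --initializer_noise argument (default 0,
    # i.e. no perturbation).
    if noise > 0:
        return embedding + noise * torch.randn_like(embedding)
    return embedding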
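Note on the Adafactor branch: with scale_parameter=True, relative_step=True, and warmup_init=True, Adafactor computes its own per-step learning rate, which is why the hunk clears args.learning_rate to None and switches args.lr_scheduler to "adafactor" (presumably mapped to transformers' AdafactorSchedule elsewhere in the script; that mapping is not shown here). A standalone sketch of the same construction, with a dummy parameter for illustration:

from functools import partial

import torch
import transformers

# Mirrors the diff's partial(...) call; beta1 and weight_decay stand in for
# args.adam_beta1 and args.adam_weight_decay.
create_optimizer = partial(
    transformers.optimization.Adafactor,
    beta1=None,
    weight_decay=0.0,
    scale_parameter=True,
    relative_step=True,
    warmup_init=True,
)

params = [torch.nn.Parameter(torch.zeros(4, 4))]
# lr must stay None: Adafactor rejects a manual lr when relative_step=True.
optimizer = create_optimizer(params, lr=None)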
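Taken together, a hypothetical invocation exercising both additions might look like the line below; the flag values are illustrative only, and the script's other required arguments are elided.

python train_ti.py \
    --optimizer adafactor \
    --initializer_noise 0.05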