path: root/train_lora.py
author     Volpeon <git@volpeon.ink>    2023-03-21 14:18:08 +0100
committer  Volpeon <git@volpeon.ink>    2023-03-21 14:18:08 +0100
commit     744b87831f5e854d86c9f39c131386c3b26e9304 (patch)
tree       66226b7a8dfe5403b2dacf2c7397833d981ab3c1 /train_lora.py
parent     Fixed SNR weighting, re-enabled xformers (diff)
Added dadaptation
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  28
1 file changed, 28 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py
index 2a798f3..ce8fb50 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -476,6 +476,34 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
+    elif args.optimizer == 'dadam':
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
+
+        create_optimizer = partial(
+            dadaptation.DAdaptAdam,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            decouple=True,
+        )
+
+        args.learning_rate = 1.0
+    elif args.optimizer == 'dadan':
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
+
+        create_optimizer = partial(
+            dadaptation.DAdaptAdan,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+        args.learning_rate = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
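For context (not part of this commit): D-Adaptation optimizers estimate the step size on the fly, which is why the diff pins args.learning_rate to 1.0. Below is a minimal standalone sketch of constructing dadaptation.DAdaptAdam with the same options as the diff; the module and the concrete hyperparameter values are placeholders, not taken from this repository.

    # Minimal sketch, assuming placeholder parameters and hyperparameter values.
    import torch
    import dadaptation

    params = torch.nn.Linear(16, 4).parameters()  # stand-in for the LoRA parameters

    optimizer = dadaptation.DAdaptAdam(
        params,
        lr=1.0,              # left at 1.0; D-Adaptation adapts the step size itself
        betas=(0.9, 0.999),  # corresponds to args.adam_beta1 / args.adam_beta2
        weight_decay=1e-2,   # corresponds to args.adam_weight_decay
        eps=1e-8,            # corresponds to args.adam_epsilon
        decouple=True,       # AdamW-style decoupled weight decay, as in the diff
    )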