summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--environment.yaml2
-rw-r--r--train_ti.py4
-rw-r--r--training/functional.py2
3 files changed, 2 insertions, 6 deletions
diff --git a/environment.yaml b/environment.yaml
index 42b568f..9c12a0b 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,6 +11,7 @@ dependencies:
11 - python=3.10.8 11 - python=3.10.8
12 - pytorch=2.0.0=*cuda11.8* 12 - pytorch=2.0.0=*cuda11.8*
13 - torchvision=0.15.0 13 - torchvision=0.15.0
14 - xformers=0.0.17.dev481
14 - pip: 15 - pip:
15 - -e . 16 - -e .
16 - -e git+https://github.com/huggingface/diffusers#egg=diffusers 17 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
@@ -24,4 +25,3 @@ dependencies:
24 - test-tube>=0.7.5 25 - test-tube>=0.7.5
25 - transformers==4.27.1 26 - transformers==4.27.1
26 - triton==2.0.0 27 - triton==2.0.0
27 - xformers==0.0.17.dev480
diff --git a/train_ti.py b/train_ti.py
index 036c288..7aeff7c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -620,8 +620,6 @@ def main():
620 eps=args.adam_epsilon, 620 eps=args.adam_epsilon,
621 decouple=True, 621 decouple=True,
622 ) 622 )
623
624 args.learning_rate = 1.0
625 elif args.optimizer == 'dadan': 623 elif args.optimizer == 'dadan':
626 try: 624 try:
627 import dadaptation 625 import dadaptation
@@ -633,8 +631,6 @@ def main():
633 weight_decay=args.adam_weight_decay, 631 weight_decay=args.adam_weight_decay,
634 eps=args.adam_epsilon, 632 eps=args.adam_epsilon,
635 ) 633 )
636
637 args.learning_rate = 1.0
638 else: 634 else:
639 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 635 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
640 636
diff --git a/training/functional.py b/training/functional.py
index 77f056e..ebb48ab 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -429,7 +429,7 @@ def train_loop(
429 try: 429 try:
430 import dadaptation 430 import dadaptation
431 431
432 isDadaptation = isinstance(optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) 432 isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
433 except ImportError: 433 except ImportError:
434 pass 434 pass
435 435