summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--environment.yaml6
-rw-r--r--train_dreambooth.py3
-rw-r--r--train_lora.py3
-rw-r--r--train_ti.py3
4 files changed, 9 insertions, 6 deletions
diff --git a/environment.yaml b/environment.yaml
index 9c12a0b..8868532 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,17 +11,17 @@ dependencies:
11 - python=3.10.8 11 - python=3.10.8
12 - pytorch=2.0.0=*cuda11.8* 12 - pytorch=2.0.0=*cuda11.8*
13 - torchvision=0.15.0 13 - torchvision=0.15.0
14 - xformers=0.0.17.dev481 14 - xformers=0.0.18.dev498
15 - pip: 15 - pip:
16 - -e . 16 - -e .
17 - -e git+https://github.com/huggingface/diffusers#egg=diffusers 17 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
18 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation 18 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
19 - accelerate==0.17.1 19 - accelerate==0.17.1
20 - bitsandbytes==0.37.1 20 - bitsandbytes==0.37.2
21 - peft==0.2.0 21 - peft==0.2.0
22 - python-slugify>=6.1.2 22 - python-slugify>=6.1.2
23 - safetensors==0.3.0 23 - safetensors==0.3.0
24 - setuptools==65.6.3 24 - setuptools==65.6.3
25 - test-tube>=0.7.5 25 - test-tube>=0.7.5
26 - transformers==4.27.1 26 - transformers==4.27.1
27 - triton==2.0.0 27 - triton==2.0.0.post1
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e3c8525..f1dca7f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -598,7 +598,8 @@ def main():
598 num_train_epochs = args.num_train_epochs 598 num_train_epochs = args.num_train_epochs
599 599
600 if num_train_epochs is None: 600 if num_train_epochs is None:
601 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 601 num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
602 num_train_epochs = math.ceil(args.num_train_steps / num_images)
602 603
603 params_to_optimize = (unet.parameters(), ) 604 params_to_optimize = (unet.parameters(), )
604 if args.train_text_encoder_epochs != 0: 605 if args.train_text_encoder_epochs != 0:
diff --git a/train_lora.py b/train_lora.py
index 6f8644b..9975462 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -630,7 +630,8 @@ def main():
630 num_train_epochs = args.num_train_epochs 630 num_train_epochs = args.num_train_epochs
631 631
632 if num_train_epochs is None: 632 if num_train_epochs is None:
633 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 633 num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
634 num_train_epochs = math.ceil(args.num_train_steps / num_images)
634 635
635 optimizer = create_optimizer( 636 optimizer = create_optimizer(
636 itertools.chain( 637 itertools.chain(
diff --git a/train_ti.py b/train_ti.py
index 9c4ad93..b7ea5f3 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -761,7 +761,8 @@ def main():
761 num_train_epochs = args.num_train_epochs 761 num_train_epochs = args.num_train_epochs
762 762
763 if num_train_epochs is None: 763 if num_train_epochs is None:
764 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 764 num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size
765 num_train_epochs = math.ceil(args.num_train_steps / num_images)
765 766
766 optimizer = create_optimizer( 767 optimizer = create_optimizer(
767 text_encoder.text_model.embeddings.temp_token_embedding.parameters(), 768 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),