From be3e05e47cded8487aaa787c54aa74770f9dcac8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 31 Mar 2023 09:34:55 +0200 Subject: Fix --- environment.yaml | 6 +++--- train_dreambooth.py | 3 ++- train_lora.py | 3 ++- train_ti.py | 3 ++- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/environment.yaml b/environment.yaml index 9c12a0b..8868532 100644 --- a/environment.yaml +++ b/environment.yaml @@ -11,17 +11,17 @@ dependencies: - python=3.10.8 - pytorch=2.0.0=*cuda11.8* - torchvision=0.15.0 - - xformers=0.0.17.dev481 + - xformers=0.0.18.dev498 - pip: - -e . - -e git+https://github.com/huggingface/diffusers#egg=diffusers - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation - accelerate==0.17.1 - - bitsandbytes==0.37.1 + - bitsandbytes==0.37.2 - peft==0.2.0 - python-slugify>=6.1.2 - safetensors==0.3.0 - setuptools==65.6.3 - test-tube>=0.7.5 - transformers==4.27.1 - - triton==2.0.0 + - triton==2.0.0.post1 diff --git a/train_dreambooth.py b/train_dreambooth.py index e3c8525..f1dca7f 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -598,7 +598,8 @@ def main(): num_train_epochs = args.num_train_epochs if num_train_epochs is None: - num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) + num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size + num_train_epochs = math.ceil(args.num_train_steps / num_images) params_to_optimize = (unet.parameters(), ) if args.train_text_encoder_epochs != 0: diff --git a/train_lora.py b/train_lora.py index 6f8644b..9975462 100644 --- a/train_lora.py +++ b/train_lora.py @@ -630,7 +630,8 @@ def main(): num_train_epochs = args.num_train_epochs if num_train_epochs is None: - num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) + num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size + num_train_epochs = math.ceil(args.num_train_steps / num_images) optimizer = create_optimizer( itertools.chain( diff --git a/train_ti.py b/train_ti.py index 9c4ad93..b7ea5f3 100644 --- a/train_ti.py +++ b/train_ti.py @@ -761,7 +761,8 @@ def main(): num_train_epochs = args.num_train_epochs if num_train_epochs is None: - num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) + num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size + num_train_epochs = math.ceil(args.num_train_steps / num_images) optimizer = create_optimizer( text_encoder.text_model.embeddings.temp_token_embedding.parameters(), -- cgit v1.2.3-54-g00ecf