diff options
-rw-r--r-- | train_dreambooth.py | 4 | ||||
-rw-r--r-- | train_lora.py | 4 | ||||
-rw-r--r-- | train_ti.py | 4 |
3 files changed, 9 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2c884d2..3a25efa 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -611,7 +611,9 @@ def main(): | |||
611 | num_train_epochs = args.num_train_epochs | 611 | num_train_epochs = args.num_train_epochs |
612 | sample_frequency = args.sample_frequency | 612 | sample_frequency = args.sample_frequency |
613 | if num_train_epochs is None: | 613 | if num_train_epochs is None: |
614 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) | 614 | num_train_epochs = math.ceil( |
615 | args.num_train_steps / len(datamodule.train_dataset) | ||
616 | ) * args.gradient_accumulation_steps | ||
615 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 617 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
616 | 618 | ||
617 | params_to_optimize = (unet.parameters(), ) | 619 | params_to_optimize = (unet.parameters(), ) |
diff --git a/train_lora.py b/train_lora.py index 59beb09..f74a438 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -643,7 +643,9 @@ def main(): | |||
643 | num_train_epochs = args.num_train_epochs | 643 | num_train_epochs = args.num_train_epochs |
644 | sample_frequency = args.sample_frequency | 644 | sample_frequency = args.sample_frequency |
645 | if num_train_epochs is None: | 645 | if num_train_epochs is None: |
646 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) | 646 | num_train_epochs = math.ceil( |
647 | args.num_train_steps / len(datamodule.train_dataset) | ||
648 | ) * args.gradient_accumulation_steps | ||
647 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 649 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
648 | 650 | ||
649 | optimizer = create_optimizer( | 651 | optimizer = create_optimizer( |
diff --git a/train_ti.py b/train_ti.py index 83043ad..dd015f9 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -774,7 +774,9 @@ def main(): | |||
774 | num_train_epochs = args.num_train_epochs | 774 | num_train_epochs = args.num_train_epochs |
775 | sample_frequency = args.sample_frequency | 775 | sample_frequency = args.sample_frequency |
776 | if num_train_epochs is None: | 776 | if num_train_epochs is None: |
777 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) | 777 | num_train_epochs = math.ceil( |
778 | args.num_train_steps / len(datamodule.train_dataset) | ||
779 | ) * args.gradient_accumulation_steps | ||
778 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 780 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
779 | 781 | ||
780 | optimizer = create_optimizer( | 782 | optimizer = create_optimizer( |