summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-31 22:17:26 +0200
committerVolpeon <git@volpeon.ink>2023-03-31 22:17:26 +0200
commitbf24028a869c849a29d23b05db0284a158d201f0 (patch)
treedc90ed7d2c6da10ded6f2c98d0f504f9372b6683
parentUpdate (diff)
downloadtextual-inversion-diff-bf24028a869c849a29d23b05db0284a158d201f0.tar.gz
textual-inversion-diff-bf24028a869c849a29d23b05db0284a158d201f0.tar.bz2
textual-inversion-diff-bf24028a869c849a29d23b05db0284a158d201f0.zip
Update
-rw-r--r--train_dreambooth.py4
-rw-r--r--train_lora.py4
-rw-r--r--train_ti.py4
3 files changed, 9 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2c884d2..3a25efa 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -611,7 +611,9 @@ def main():
611 num_train_epochs = args.num_train_epochs 611 num_train_epochs = args.num_train_epochs
612 sample_frequency = args.sample_frequency 612 sample_frequency = args.sample_frequency
613 if num_train_epochs is None: 613 if num_train_epochs is None:
614 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 614 num_train_epochs = math.ceil(
615 args.num_train_steps / len(datamodule.train_dataset)
616 ) * args.gradient_accumulation_steps
615 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) 617 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
616 618
617 params_to_optimize = (unet.parameters(), ) 619 params_to_optimize = (unet.parameters(), )
diff --git a/train_lora.py b/train_lora.py
index 59beb09..f74a438 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -643,7 +643,9 @@ def main():
643 num_train_epochs = args.num_train_epochs 643 num_train_epochs = args.num_train_epochs
644 sample_frequency = args.sample_frequency 644 sample_frequency = args.sample_frequency
645 if num_train_epochs is None: 645 if num_train_epochs is None:
646 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 646 num_train_epochs = math.ceil(
647 args.num_train_steps / len(datamodule.train_dataset)
648 ) * args.gradient_accumulation_steps
647 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) 649 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
648 650
649 optimizer = create_optimizer( 651 optimizer = create_optimizer(
diff --git a/train_ti.py b/train_ti.py
index 83043ad..dd015f9 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -774,7 +774,9 @@ def main():
774 num_train_epochs = args.num_train_epochs 774 num_train_epochs = args.num_train_epochs
775 sample_frequency = args.sample_frequency 775 sample_frequency = args.sample_frequency
776 if num_train_epochs is None: 776 if num_train_epochs is None:
777 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 777 num_train_epochs = math.ceil(
778 args.num_train_steps / len(datamodule.train_dataset)
779 ) * args.gradient_accumulation_steps
778 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) 780 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
779 781
780 optimizer = create_optimizer( 782 optimizer = create_optimizer(