summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py4
-rw-r--r--train_lora.py4
-rw-r--r--train_ti.py3
3 files changed, 6 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d2e60ec..0634376 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -593,7 +593,6 @@ def main():
593 dropout=args.tag_dropout, 593 dropout=args.tag_dropout,
594 shuffle=not args.no_tag_shuffle, 594 shuffle=not args.no_tag_shuffle,
595 template_key=args.train_data_template, 595 template_key=args.train_data_template,
596 placeholder_tokens=args.placeholder_tokens,
597 valid_set_size=args.valid_set_size, 596 valid_set_size=args.valid_set_size,
598 train_set_pad=args.train_set_pad, 597 train_set_pad=args.train_set_pad,
599 valid_set_pad=args.valid_set_pad, 598 valid_set_pad=args.valid_set_pad,
@@ -604,9 +603,10 @@ def main():
604 datamodule.setup() 603 datamodule.setup()
605 604
606 num_train_epochs = args.num_train_epochs 605 num_train_epochs = args.num_train_epochs
606 sample_frequency = args.sample_frequency
607 if num_train_epochs is None: 607 if num_train_epochs is None:
608 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 608 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
609 sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) 609 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
610 610
611 params_to_optimize = (unet.parameters(), ) 611 params_to_optimize = (unet.parameters(), )
612 if args.train_text_encoder_epochs != 0: 612 if args.train_text_encoder_epochs != 0:
diff --git a/train_lora.py b/train_lora.py
index 7b54ef8..d89b18d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -625,7 +625,6 @@ def main():
625 dropout=args.tag_dropout, 625 dropout=args.tag_dropout,
626 shuffle=not args.no_tag_shuffle, 626 shuffle=not args.no_tag_shuffle,
627 template_key=args.train_data_template, 627 template_key=args.train_data_template,
628 placeholder_tokens=args.placeholder_tokens,
629 valid_set_size=args.valid_set_size, 628 valid_set_size=args.valid_set_size,
630 train_set_pad=args.train_set_pad, 629 train_set_pad=args.train_set_pad,
631 valid_set_pad=args.valid_set_pad, 630 valid_set_pad=args.valid_set_pad,
@@ -636,9 +635,10 @@ def main():
636 datamodule.setup() 635 datamodule.setup()
637 636
638 num_train_epochs = args.num_train_epochs 637 num_train_epochs = args.num_train_epochs
638 sample_frequency = args.sample_frequency
639 if num_train_epochs is None: 639 if num_train_epochs is None:
640 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 640 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
641 sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) 641 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
642 642
643 optimizer = create_optimizer( 643 optimizer = create_optimizer(
644 itertools.chain( 644 itertools.chain(
diff --git a/train_ti.py b/train_ti.py
index 7900fbd..b182a72 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -766,9 +766,10 @@ def main():
766 datamodule.setup() 766 datamodule.setup()
767 767
768 num_train_epochs = args.num_train_epochs 768 num_train_epochs = args.num_train_epochs
769 sample_frequency = args.sample_frequency
769 if num_train_epochs is None: 770 if num_train_epochs is None:
770 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 771 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
771 sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) 772 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
772 773
773 optimizer = create_optimizer( 774 optimizer = create_optimizer(
774 text_encoder.text_model.embeddings.temp_token_embedding.parameters(), 775 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),