summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py8
-rw-r--r--train_dreambooth.py17
-rw-r--r--train_lora.py17
-rw-r--r--train_ti.py30
4 files changed, 49 insertions, 23 deletions
diff --git a/data/csv.py b/data/csv.py
index 9770bec..c00ea07 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -143,7 +143,7 @@ class VlpnDataItem(NamedTuple):
143 143
144def keyword_filter( 144def keyword_filter(
145 placeholder_tokens: Optional[list[str]], 145 placeholder_tokens: Optional[list[str]],
146 collection: Optional[list[str]], 146 collections: Optional[list[str]],
147 exclude_collections: Optional[list[str]], 147 exclude_collections: Optional[list[str]],
148 item: VlpnDataItem 148 item: VlpnDataItem
149): 149):
@@ -152,11 +152,15 @@ def keyword_filter(
152 for keyword in placeholder_tokens 152 for keyword in placeholder_tokens
153 for part in item.prompt 153 for part in item.prompt
154 ) 154 )
155 cond2 = collection is None or collection in item.collection 155 cond2 = collections is None or any(
156 collection in item.collection
157 for collection in collections
158 )
156 cond3 = exclude_collections is None or not any( 159 cond3 = exclude_collections is None or not any(
157 collection in item.collection 160 collection in item.collection
158 for collection in exclude_collections 161 for collection in exclude_collections
159 ) 162 )
163
160 return cond1 and cond2 and cond3 164 return cond1 and cond2 and cond3
161 165
162 166
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9345797..acb8287 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -4,6 +4,7 @@ import logging
4import itertools 4import itertools
5from pathlib import Path 5from pathlib import Path
6from functools import partial 6from functools import partial
7import math
7 8
8import torch 9import torch
9import torch.utils.checkpoint 10import torch.utils.checkpoint
@@ -189,13 +190,12 @@ def parse_args():
189 parser.add_argument( 190 parser.add_argument(
190 "--num_train_epochs", 191 "--num_train_epochs",
191 type=int, 192 type=int,
192 default=100 193 default=None
193 ) 194 )
194 parser.add_argument( 195 parser.add_argument(
195 "--max_train_steps", 196 "--num_train_steps",
196 type=int, 197 type=int,
197 default=None, 198 default=2000
198 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
199 ) 199 )
200 parser.add_argument( 200 parser.add_argument(
201 "--gradient_accumulation_steps", 201 "--gradient_accumulation_steps",
@@ -595,6 +595,11 @@ def main():
595 ) 595 )
596 datamodule.setup() 596 datamodule.setup()
597 597
598 num_train_epochs = args.num_train_epochs
599
600 if num_train_epochs is None:
601 num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
602
598 params_to_optimize = (unet.parameters(), ) 603 params_to_optimize = (unet.parameters(), )
599 if args.train_text_encoder_epochs != 0: 604 if args.train_text_encoder_epochs != 0:
600 params_to_optimize += ( 605 params_to_optimize += (
@@ -619,7 +624,7 @@ def main():
619 annealing_exp=args.lr_annealing_exp, 624 annealing_exp=args.lr_annealing_exp,
620 cycles=args.lr_cycles, 625 cycles=args.lr_cycles,
621 end_lr=1e2, 626 end_lr=1e2,
622 train_epochs=args.num_train_epochs, 627 train_epochs=num_train_epochs,
623 warmup_epochs=args.lr_warmup_epochs, 628 warmup_epochs=args.lr_warmup_epochs,
624 ) 629 )
625 630
@@ -631,7 +636,7 @@ def main():
631 seed=args.seed, 636 seed=args.seed,
632 optimizer=optimizer, 637 optimizer=optimizer,
633 lr_scheduler=lr_scheduler, 638 lr_scheduler=lr_scheduler,
634 num_train_epochs=args.num_train_epochs, 639 num_train_epochs=num_train_epochs,
635 gradient_accumulation_steps=args.gradient_accumulation_steps, 640 gradient_accumulation_steps=args.gradient_accumulation_steps,
636 sample_frequency=args.sample_frequency, 641 sample_frequency=args.sample_frequency,
637 offset_noise_strength=args.offset_noise_strength, 642 offset_noise_strength=args.offset_noise_strength,
diff --git a/train_lora.py b/train_lora.py
index 7ecddf0..a9c6e52 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -4,6 +4,7 @@ import logging
4import itertools 4import itertools
5from pathlib import Path 5from pathlib import Path
6from functools import partial 6from functools import partial
7import math
7 8
8import torch 9import torch
9import torch.utils.checkpoint 10import torch.utils.checkpoint
@@ -178,13 +179,12 @@ def parse_args():
178 parser.add_argument( 179 parser.add_argument(
179 "--num_train_epochs", 180 "--num_train_epochs",
180 type=int, 181 type=int,
181 default=100 182 default=None
182 ) 183 )
183 parser.add_argument( 184 parser.add_argument(
184 "--max_train_steps", 185 "--num_train_steps",
185 type=int, 186 type=int,
186 default=None, 187 default=2000
187 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
188 ) 188 )
189 parser.add_argument( 189 parser.add_argument(
190 "--gradient_accumulation_steps", 190 "--gradient_accumulation_steps",
@@ -627,6 +627,11 @@ def main():
627 ) 627 )
628 datamodule.setup() 628 datamodule.setup()
629 629
630 num_train_epochs = args.num_train_epochs
631
632 if num_train_epochs is None:
633 num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
634
630 optimizer = create_optimizer( 635 optimizer = create_optimizer(
631 itertools.chain( 636 itertools.chain(
632 unet.parameters(), 637 unet.parameters(),
@@ -647,7 +652,7 @@ def main():
647 annealing_exp=args.lr_annealing_exp, 652 annealing_exp=args.lr_annealing_exp,
648 cycles=args.lr_cycles, 653 cycles=args.lr_cycles,
649 end_lr=1e2, 654 end_lr=1e2,
650 train_epochs=args.num_train_epochs, 655 train_epochs=num_train_epochs,
651 warmup_epochs=args.lr_warmup_epochs, 656 warmup_epochs=args.lr_warmup_epochs,
652 ) 657 )
653 658
@@ -659,7 +664,7 @@ def main():
659 seed=args.seed, 664 seed=args.seed,
660 optimizer=optimizer, 665 optimizer=optimizer,
661 lr_scheduler=lr_scheduler, 666 lr_scheduler=lr_scheduler,
662 num_train_epochs=args.num_train_epochs, 667 num_train_epochs=num_train_epochs,
663 gradient_accumulation_steps=args.gradient_accumulation_steps, 668 gradient_accumulation_steps=args.gradient_accumulation_steps,
664 sample_frequency=args.sample_frequency, 669 sample_frequency=args.sample_frequency,
665 offset_noise_strength=args.offset_noise_strength, 670 offset_noise_strength=args.offset_noise_strength,
diff --git a/train_ti.py b/train_ti.py
index e4fd464..7bcc72f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
3import logging 3import logging
4from functools import partial 4from functools import partial
5from pathlib import Path 5from pathlib import Path
6import math
6 7
7import torch 8import torch
8import torch.utils.checkpoint 9import torch.utils.checkpoint
@@ -207,7 +208,12 @@ def parse_args():
207 parser.add_argument( 208 parser.add_argument(
208 "--num_train_epochs", 209 "--num_train_epochs",
209 type=int, 210 type=int,
210 default=100 211 default=None
212 )
213 parser.add_argument(
214 "--num_train_steps",
215 type=int,
216 default=2000
211 ) 217 )
212 parser.add_argument( 218 parser.add_argument(
213 "--gradient_accumulation_steps", 219 "--gradient_accumulation_steps",
@@ -513,13 +519,13 @@ def parse_args():
513 if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: 519 if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
514 raise ValueError("--alias_tokens must be a list with an even number of items") 520 raise ValueError("--alias_tokens must be a list with an even number of items")
515 521
516 args.alias_tokens += [
517 item
518 for pair in zip(args.placeholder_tokens, args.initializer_tokens)
519 for item in pair
520 ]
521
522 if args.sequential: 522 if args.sequential:
523 args.alias_tokens += [
524 item
525 for pair in zip(args.placeholder_tokens, args.initializer_tokens)
526 for item in pair
527 ]
528
523 if isinstance(args.train_data_template, str): 529 if isinstance(args.train_data_template, str):
524 args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) 530 args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
525 531
@@ -607,6 +613,7 @@ def main():
607 raise ValueError("--embeddings_dir must point to an existing directory") 613 raise ValueError("--embeddings_dir must point to an existing directory")
608 614
609 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 615 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
616 embeddings.persist()
610 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 617 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
611 618
612 if args.scale_lr: 619 if args.scale_lr:
@@ -682,7 +689,6 @@ def main():
682 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, 689 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
683 no_val=args.valid_set_size == 0, 690 no_val=args.valid_set_size == 0,
684 strategy=textual_inversion_strategy, 691 strategy=textual_inversion_strategy,
685 num_train_epochs=args.num_train_epochs,
686 gradient_accumulation_steps=args.gradient_accumulation_steps, 692 gradient_accumulation_steps=args.gradient_accumulation_steps,
687 sample_frequency=args.sample_frequency, 693 sample_frequency=args.sample_frequency,
688 checkpoint_frequency=args.checkpoint_frequency, 694 checkpoint_frequency=args.checkpoint_frequency,
@@ -752,6 +758,11 @@ def main():
752 ) 758 )
753 datamodule.setup() 759 datamodule.setup()
754 760
761 num_train_epochs = args.num_train_epochs
762
763 if num_train_epochs is None:
764 num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
765
755 optimizer = create_optimizer( 766 optimizer = create_optimizer(
756 text_encoder.text_model.embeddings.temp_token_embedding.parameters(), 767 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
757 lr=args.learning_rate, 768 lr=args.learning_rate,
@@ -769,7 +780,7 @@ def main():
769 annealing_exp=args.lr_annealing_exp, 780 annealing_exp=args.lr_annealing_exp,
770 cycles=args.lr_cycles, 781 cycles=args.lr_cycles,
771 end_lr=1e3, 782 end_lr=1e3,
772 train_epochs=args.num_train_epochs, 783 train_epochs=num_train_epochs,
773 warmup_epochs=args.lr_warmup_epochs, 784 warmup_epochs=args.lr_warmup_epochs,
774 ) 785 )
775 786
@@ -779,6 +790,7 @@ def main():
779 val_dataloader=datamodule.val_dataloader, 790 val_dataloader=datamodule.val_dataloader,
780 optimizer=optimizer, 791 optimizer=optimizer,
781 lr_scheduler=lr_scheduler, 792 lr_scheduler=lr_scheduler,
793 num_train_epochs=num_train_epochs,
782 # -- 794 # --
783 sample_output_dir=sample_output_dir, 795 sample_output_dir=sample_output_dir,
784 placeholder_tokens=placeholder_tokens, 796 placeholder_tokens=placeholder_tokens,