 data/csv.py         |  8 ++++++--
 train_dreambooth.py | 17 +++++++++++------
 train_lora.py       | 17 +++++++++++------
 train_ti.py         | 30 +++++++++++++++++++++---------
 4 files changed, 49 insertions(+), 23 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 9770bec..c00ea07 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -143,7 +143,7 @@ class VlpnDataItem(NamedTuple):
 
 def keyword_filter(
     placeholder_tokens: Optional[list[str]],
-    collection: Optional[list[str]],
+    collections: Optional[list[str]],
     exclude_collections: Optional[list[str]],
     item: VlpnDataItem
 ):
@@ -152,11 +152,15 @@ def keyword_filter(
         for keyword in placeholder_tokens
         for part in item.prompt
     )
-    cond2 = collection is None or collection in item.collection
+    cond2 = collections is None or any(
+        collection in item.collection
+        for collection in collections
+    )
     cond3 = exclude_collections is None or not any(
         collection in item.collection
         for collection in exclude_collections
     )
+
     return cond1 and cond2 and cond3
 
 
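Before this change, cond2 evaluated `collection in item.collection` with `collection` annotated as Optional[list[str]], i.e. it asked whether the whole list was an element of item.collection, so a non-None filter could never match. The hunk renames the parameter to `collections` and OR-matches its elements, mirroring the any() pattern cond3 already uses for exclusions. A minimal, self-contained sketch of the new semantics (hypothetical data; VlpnDataItem reduced to the one field the condition reads):

from typing import NamedTuple, Optional

class Item(NamedTuple):
    collection: list[str]

def matches(collections: Optional[list[str]], item: Item) -> bool:
    # New cond2: pass when ANY requested collection tag is present.
    return collections is None or any(
        collection in item.collection
        for collection in collections
    )

item = Item(collection=["portraits"])
assert matches(None, item)                   # no filter: everything passes
assert matches(["portraits", "pets"], item)  # OR across requested names
assert not matches(["pets"], item)
# The old scalar test compared the whole filter list against each tag,
# which is why passing a list of names silently filtered everything out.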
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9345797..acb8287 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -189,13 +190,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -595,6 +595,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     params_to_optimize = (unet.parameters(), )
     if args.train_text_encoder_epochs != 0:
         params_to_optimize += (
@@ -619,7 +624,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -631,7 +636,7 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
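train_dreambooth.py drops `--max_train_steps` in favor of `--num_train_steps` (default 2000) and derives an epoch count when `--num_train_epochs` is unset. A hedged sketch of that derivation, assuming the intent is "round the step budget up to whole epochs"; note the hunk above divides len(datamodule.train_dataset) by the step budget, the reverse operand order, so this sketch is an interpretation, not a transcript:

import math

def derive_epochs(num_train_steps: int, steps_per_epoch: int) -> int:
    # Assumed intent: run at least num_train_steps optimizer steps,
    # rounded up to whole epochs (steps_per_epoch = batches per epoch).
    return math.ceil(num_train_steps / steps_per_epoch)

# With the new 2000-step default and, say, 250 batches per epoch:
assert derive_epochs(2000, 250) == 8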
diff --git a/train_lora.py b/train_lora.py
index 7ecddf0..a9c6e52 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -178,13 +179,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -627,6 +627,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     optimizer = create_optimizer(
         itertools.chain(
             unet.parameters(),
@@ -647,7 +652,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -659,7 +664,7 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
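train_lora.py picks up the identical change. The resulting CLI contract, reduced to a runnable argparse sketch (the two flag names are the real ones; the rest is scaffolding): explicit epochs take precedence, otherwise the step budget applies.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_train_epochs", type=int, default=None)
parser.add_argument("--num_train_steps", type=int, default=2000)

args = parser.parse_args([])
assert args.num_train_epochs is None and args.num_train_steps == 2000

args = parser.parse_args(["--num_train_epochs", "50"])
assert args.num_train_epochs == 50  # explicit epochs override the step budget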
diff --git a/train_ti.py b/train_ti.py
index e4fd464..7bcc72f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -207,7 +208,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
+    )
+    parser.add_argument(
+        "--num_train_steps",
+        type=int,
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -513,13 +519,13 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
-    args.alias_tokens += [
-        item
-        for pair in zip(args.placeholder_tokens, args.initializer_tokens)
-        for item in pair
-    ]
-
     if args.sequential:
+        args.alias_tokens += [
+            item
+            for pair in zip(args.placeholder_tokens, args.initializer_tokens)
+            for item in pair
+        ]
+
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -607,6 +613,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
@@ -682,7 +689,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
-        num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
@@ -752,6 +758,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
@@ -769,7 +780,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e3,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -779,6 +790,7 @@ def main():
         val_dataloader=datamodule.val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        num_train_epochs=num_train_epochs,
         # --
         sample_output_dir=sample_output_dir,
         placeholder_tokens=placeholder_tokens,
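In train_ti.py, besides persisting embeddings loaded from --embeddings_dir and threading num_train_epochs through to the trainer, the placeholder/initializer interleaving now runs only under --sequential. A sketch of what that flattening produces, with hypothetical token lists:

# Hypothetical token lists; the real values come from the CLI flags.
placeholder_tokens = ["<tok1>", "<tok2>"]
initializer_tokens = ["woman", "dog"]

alias_tokens = [
    item
    for pair in zip(placeholder_tokens, initializer_tokens)
    for item in pair
]
assert alias_tokens == ["<tok1>", "woman", "<tok2>", "dog"]
# The flat [placeholder, initializer, ...] ordering keeps the even-length
# invariant that the --alias_tokens validation in parse_args checks.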
