Diffstat (limited to 'train_ti.py')
 train_ti.py | 62 ++++++++++++++++++----------------------------------------
 1 file changed, 22 insertions(+), 40 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 2497519..48a2333 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,7 @@ from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from data.csv import VlpnDataModule, VlpnDataItem
+from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -446,15 +446,15 @@ def parse_args():
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)]
-
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
     if len(args.initializer_tokens) == 0:
         raise ValueError("You must specify --initializer_tokens")
 
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
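Note: besides moving the block, this hunk fixes a latent bug. The deleted branch ran before `args.initializer_tokens` was normalized to a list and called `range(args.initializer_tokens)` on a string or list, which raises `TypeError`; the replacement runs after the `--initializer_tokens` validation and takes `len(...)` first. A minimal standalone illustration (the sample values are hypothetical):

    # Old ordering in isolation: range() needs an int, but at this point
    # initializer_tokens is still a str or a list of str.
    initializer_tokens = ["dog", "cat"]
    try:
        placeholder_tokens = [f"<*{i}>" for i in range(initializer_tokens)]
    except TypeError as err:
        print(err)  # 'list' object cannot be interpreted as an integer

    # New ordering: normalize and validate first, then derive the defaults.
    placeholder_tokens = [f"<*{i}>" for i in range(len(initializer_tokens))]
    print(placeholder_tokens)  # ['<*0>', '<*1>']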
@@ -544,9 +544,6 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
 
-    if args.find_lr:
-        args.learning_rate = 1e-5
-
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -563,19 +560,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: VlpnDataItem):
-        cond1 = any(
-            keyword in part
-            for keyword in args.placeholder_tokens
-            for part in item.prompt
-        )
-        cond3 = args.collection is None or args.collection in item.collection
-        cond4 = args.exclude_collections is None or not any(
-            collection in item.collection
-            for collection in args.exclude_collections
-        )
-        return cond1 and cond3 and cond4
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -593,7 +577,7 @@ def main():
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=keyword_filter,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
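Note: `keyword_filter` now lives in `data.csv` (see the import change in the first hunk) and is bound to its configuration with `functools.partial` instead of closing over `args`, so `VlpnDataModule` still receives a one-argument callback. Reconstructed from the deleted closure and the new call site, the relocated helper presumably looks like the sketch below; the parameter order is implied by the `partial(...)` call, but the exact names in `data/csv.py` are assumptions, as is the `functools` import in `train_ti.py` (not shown in this diff):

    from functools import partial

    # Sketch of the relocated filter; the real data/csv.py may differ.
    def keyword_filter(placeholder_tokens, collection, exclude_collections, item):
        # Keep the item only if some placeholder token occurs in its prompt...
        has_token = any(
            keyword in part
            for keyword in placeholder_tokens
            for part in item.prompt
        )
        # ...and it matches the requested collection (if one was given)...
        in_collection = collection is None or collection in item.collection
        # ...and it is not in any excluded collection.
        not_excluded = exclude_collections is None or not any(
            c in item.collection for c in exclude_collections
        )
        return has_token and in_collection and not_excluded

    # Binding everything except `item` reproduces the call site above:
    # partial(keyword_filter, args.placeholder_tokens, args.collection,
    #         args.exclude_collections)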
@@ -622,8 +606,6 @@ def main():
         text_encoder=text_encoder,
         vae=vae,
         noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
         dtype=weight_dtype,
         seed=args.seed,
         callbacks_fn=textual_inversion_strategy
@@ -638,25 +620,25 @@ def main():
         amsgrad=args.adam_amsgrad,
     )
 
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            min_lr=args.lr_min_lr,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-        )
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
 
     trainer(
+        project="textual_inversion",
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
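Note: this hunk, together with the one at old line 547, drops the `--find_lr` special-casing entirely: the learning-rate override disappears and `get_scheduler` is now called unconditionally. The dataloaders and a `project` label likewise move from the trainer factory (hunk at old line 622) to the `trainer(...)` call itself. For context, the two deleted `find_lr` blocks were coupled and only make sense as a pair (the comments are this editor's reading, not from the source):

    if args.find_lr:                 # removed at old lines 547-549
        args.learning_rate = 1e-5    # presumably a fixed starting LR for the sweep

    # ... intervening code unchanged ...

    if args.find_lr:                 # removed at old lines 641-643
        lr_scheduler = None          # presumably the LR finder drove the LR itself
    else:
        lr_scheduler = get_scheduler(...)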
