diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 62 |
1 file changed, 22 insertions, 40 deletions
diff --git a/train_ti.py b/train_ti.py index 2497519..48a2333 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -13,7 +13,7 @@ from accelerate.utils import LoggerType, set_seed | |||
13 | from slugify import slugify | 13 | from slugify import slugify |
14 | 14 | ||
15 | from util import load_config, load_embeddings_from_dir | 15 | from util import load_config, load_embeddings_from_dir |
16 | from data.csv import VlpnDataModule, VlpnDataItem | 16 | from data.csv import VlpnDataModule, keyword_filter |
17 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models | 17 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models |
18 | from training.strategy.ti import textual_inversion_strategy | 18 | from training.strategy.ti import textual_inversion_strategy |
19 | from training.optimization import get_scheduler | 19 | from training.optimization import get_scheduler |
@@ -446,15 +446,15 @@ def parse_args(): | |||
446 | if isinstance(args.placeholder_tokens, str): | 446 | if isinstance(args.placeholder_tokens, str): |
447 | args.placeholder_tokens = [args.placeholder_tokens] | 447 | args.placeholder_tokens = [args.placeholder_tokens] |
448 | 448 | ||
449 | if len(args.placeholder_tokens) == 0: | ||
450 | args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)] | ||
451 | |||
452 | if isinstance(args.initializer_tokens, str): | 449 | if isinstance(args.initializer_tokens, str): |
453 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 450 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) |
454 | 451 | ||
455 | if len(args.initializer_tokens) == 0: | 452 | if len(args.initializer_tokens) == 0: |
456 | raise ValueError("You must specify --initializer_tokens") | 453 | raise ValueError("You must specify --initializer_tokens") |
457 | 454 | ||
455 | if len(args.placeholder_tokens) == 0: | ||
456 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | ||
457 | |||
458 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 458 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
459 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 459 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
460 | 460 | ||
@@ -544,9 +544,6 @@ def main(): | |||
544 | args.train_batch_size * accelerator.num_processes | 544 | args.train_batch_size * accelerator.num_processes |
545 | ) | 545 | ) |
546 | 546 | ||
547 | if args.find_lr: | ||
548 | args.learning_rate = 1e-5 | ||
549 | |||
550 | if args.use_8bit_adam: | 547 | if args.use_8bit_adam: |
551 | try: | 548 | try: |
552 | import bitsandbytes as bnb | 549 | import bitsandbytes as bnb |
@@ -563,19 +560,6 @@ def main(): | |||
563 | elif args.mixed_precision == "bf16": | 560 | elif args.mixed_precision == "bf16": |
564 | weight_dtype = torch.bfloat16 | 561 | weight_dtype = torch.bfloat16 |
565 | 562 | ||
566 | def keyword_filter(item: VlpnDataItem): | ||
567 | cond1 = any( | ||
568 | keyword in part | ||
569 | for keyword in args.placeholder_tokens | ||
570 | for part in item.prompt | ||
571 | ) | ||
572 | cond3 = args.collection is None or args.collection in item.collection | ||
573 | cond4 = args.exclude_collections is None or not any( | ||
574 | collection in item.collection | ||
575 | for collection in args.exclude_collections | ||
576 | ) | ||
577 | return cond1 and cond3 and cond4 | ||
578 | |||
579 | datamodule = VlpnDataModule( | 563 | datamodule = VlpnDataModule( |
580 | data_file=args.train_data_file, | 564 | data_file=args.train_data_file, |
581 | batch_size=args.train_batch_size, | 565 | batch_size=args.train_batch_size, |
@@ -593,7 +577,7 @@ def main(): | |||
593 | valid_set_size=args.valid_set_size, | 577 | valid_set_size=args.valid_set_size, |
594 | valid_set_repeat=args.valid_set_repeat, | 578 | valid_set_repeat=args.valid_set_repeat, |
595 | seed=args.seed, | 579 | seed=args.seed, |
596 | filter=keyword_filter, | 580 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), |
597 | dtype=weight_dtype | 581 | dtype=weight_dtype |
598 | ) | 582 | ) |
599 | datamodule.setup() | 583 | datamodule.setup() |
@@ -622,8 +606,6 @@ def main(): | |||
622 | text_encoder=text_encoder, | 606 | text_encoder=text_encoder, |
623 | vae=vae, | 607 | vae=vae, |
624 | noise_scheduler=noise_scheduler, | 608 | noise_scheduler=noise_scheduler, |
625 | train_dataloader=train_dataloader, | ||
626 | val_dataloader=val_dataloader, | ||
627 | dtype=weight_dtype, | 609 | dtype=weight_dtype, |
628 | seed=args.seed, | 610 | seed=args.seed, |
629 | callbacks_fn=textual_inversion_strategy | 611 | callbacks_fn=textual_inversion_strategy |
@@ -638,25 +620,25 @@ def main(): | |||
638 | amsgrad=args.adam_amsgrad, | 620 | amsgrad=args.adam_amsgrad, |
639 | ) | 621 | ) |
640 | 622 | ||
641 | if args.find_lr: | 623 | lr_scheduler = get_scheduler( |
642 | lr_scheduler = None | 624 | args.lr_scheduler, |
643 | else: | 625 | optimizer=optimizer, |
644 | lr_scheduler = get_scheduler( | 626 | num_training_steps_per_epoch=len(train_dataloader), |
645 | args.lr_scheduler, | 627 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
646 | optimizer=optimizer, | 628 | min_lr=args.lr_min_lr, |
647 | num_training_steps_per_epoch=len(train_dataloader), | 629 | warmup_func=args.lr_warmup_func, |
648 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 630 | annealing_func=args.lr_annealing_func, |
649 | min_lr=args.lr_min_lr, | 631 | warmup_exp=args.lr_warmup_exp, |
650 | warmup_func=args.lr_warmup_func, | 632 | annealing_exp=args.lr_annealing_exp, |
651 | annealing_func=args.lr_annealing_func, | 633 | cycles=args.lr_cycles, |
652 | warmup_exp=args.lr_warmup_exp, | 634 | train_epochs=args.num_train_epochs, |
653 | annealing_exp=args.lr_annealing_exp, | 635 | warmup_epochs=args.lr_warmup_epochs, |
654 | cycles=args.lr_cycles, | 636 | ) |
655 | train_epochs=args.num_train_epochs, | ||
656 | warmup_epochs=args.lr_warmup_epochs, | ||
657 | ) | ||
658 | 637 | ||
659 | trainer( | 638 | trainer( |
639 | project="textual_inversion", | ||
640 | train_dataloader=train_dataloader, | ||
641 | val_dataloader=val_dataloader, | ||
660 | optimizer=optimizer, | 642 | optimizer=optimizer, |
661 | lr_scheduler=lr_scheduler, | 643 | lr_scheduler=lr_scheduler, |
662 | num_train_epochs=args.num_train_epochs, | 644 | num_train_epochs=args.num_train_epochs, |