diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 113 |
1 files changed, 64 insertions, 49 deletions
diff --git a/train_ti.py b/train_ti.py index e7aeb23..0891c49 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -14,7 +14,7 @@ from slugify import slugify | |||
14 | 14 | ||
15 | from util import load_config, load_embeddings_from_dir | 15 | from util import load_config, load_embeddings_from_dir |
16 | from data.csv import VlpnDataModule, keyword_filter | 16 | from data.csv import VlpnDataModule, keyword_filter |
17 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models | 17 | from training.functional import train, add_placeholder_tokens, get_models |
18 | from training.strategy.ti import textual_inversion_strategy | 18 | from training.strategy.ti import textual_inversion_strategy |
19 | from training.optimization import get_scheduler | 19 | from training.optimization import get_scheduler |
20 | from training.util import save_args | 20 | from training.util import save_args |
@@ -79,6 +79,10 @@ def parse_args(): | |||
79 | help="Number of vectors per embedding." | 79 | help="Number of vectors per embedding." |
80 | ) | 80 | ) |
81 | parser.add_argument( | 81 | parser.add_argument( |
82 | "--simultaneous", | ||
83 | action="store_true", | ||
84 | ) | ||
85 | parser.add_argument( | ||
82 | "--num_class_images", | 86 | "--num_class_images", |
83 | type=int, | 87 | type=int, |
84 | default=0, | 88 | default=0, |
@@ -474,11 +478,12 @@ def parse_args(): | |||
474 | if len(args.placeholder_tokens) != len(args.num_vectors): | 478 | if len(args.placeholder_tokens) != len(args.num_vectors): |
475 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 479 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
476 | 480 | ||
477 | if isinstance(args.train_data_template, str): | 481 | if not args.simultaneous: |
478 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 482 | if isinstance(args.train_data_template, str): |
483 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | ||
479 | 484 | ||
480 | if len(args.placeholder_tokens) != len(args.train_data_template): | 485 | if len(args.placeholder_tokens) != len(args.train_data_template): |
481 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | 486 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") |
482 | 487 | ||
483 | if isinstance(args.collection, str): | 488 | if isinstance(args.collection, str): |
484 | args.collection = [args.collection] | 489 | args.collection = [args.collection] |
@@ -560,6 +565,8 @@ def main(): | |||
560 | elif args.mixed_precision == "bf16": | 565 | elif args.mixed_precision == "bf16": |
561 | weight_dtype = torch.bfloat16 | 566 | weight_dtype = torch.bfloat16 |
562 | 567 | ||
568 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | ||
569 | |||
563 | trainer = partial( | 570 | trainer = partial( |
564 | train, | 571 | train, |
565 | accelerator=accelerator, | 572 | accelerator=accelerator, |
@@ -569,30 +576,50 @@ def main(): | |||
569 | noise_scheduler=noise_scheduler, | 576 | noise_scheduler=noise_scheduler, |
570 | dtype=weight_dtype, | 577 | dtype=weight_dtype, |
571 | seed=args.seed, | 578 | seed=args.seed, |
572 | callbacks_fn=textual_inversion_strategy | 579 | with_prior_preservation=args.num_class_images != 0, |
573 | ) | 580 | prior_loss_weight=args.prior_loss_weight, |
574 | 581 | strategy=textual_inversion_strategy, | |
575 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | 582 | num_train_epochs=args.num_train_epochs, |
576 | 583 | sample_frequency=args.sample_frequency, | |
577 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | 584 | checkpoint_frequency=args.checkpoint_frequency, |
578 | range(len(args.placeholder_tokens)), | 585 | global_step_offset=global_step_offset, |
579 | args.placeholder_tokens, | 586 | # -- |
580 | args.initializer_tokens, | 587 | tokenizer=tokenizer, |
581 | args.num_vectors, | 588 | sample_scheduler=sample_scheduler, |
582 | args.train_data_template | 589 | checkpoint_output_dir=checkpoint_output_dir, |
583 | ): | 590 | learning_rate=args.learning_rate, |
584 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}") | 591 | gradient_checkpointing=args.gradient_checkpointing, |
592 | use_emb_decay=args.use_emb_decay, | ||
593 | emb_decay_target=args.emb_decay_target, | ||
594 | emb_decay_factor=args.emb_decay_factor, | ||
595 | emb_decay_start=args.emb_decay_start, | ||
596 | use_ema=args.use_ema, | ||
597 | ema_inv_gamma=args.ema_inv_gamma, | ||
598 | ema_power=args.ema_power, | ||
599 | ema_max_decay=args.ema_max_decay, | ||
600 | sample_batch_size=args.sample_batch_size, | ||
601 | sample_num_batches=args.sample_batches, | ||
602 | sample_num_steps=args.sample_steps, | ||
603 | sample_image_size=args.sample_image_size, | ||
604 | ) | ||
605 | |||
606 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): | ||
607 | if len(placeholder_tokens) == 1: | ||
608 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token[0]}") | ||
609 | else: | ||
610 | sample_output_dir = output_dir.joinpath("samples") | ||
585 | 611 | ||
586 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 612 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
587 | tokenizer=tokenizer, | 613 | tokenizer=tokenizer, |
588 | embeddings=embeddings, | 614 | embeddings=embeddings, |
589 | placeholder_tokens=[placeholder_token], | 615 | placeholder_tokens=placeholder_tokens, |
590 | initializer_tokens=[initializer_token], | 616 | initializer_tokens=initializer_tokens, |
591 | num_vectors=[num_vectors] | 617 | num_vectors=num_vectors |
592 | ) | 618 | ) |
593 | 619 | ||
594 | print( | 620 | stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) |
595 | f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") | 621 | |
622 | print(f"{i + 1}: {stats})") | ||
596 | 623 | ||
597 | datamodule = VlpnDataModule( | 624 | datamodule = VlpnDataModule( |
598 | data_file=args.train_data_file, | 625 | data_file=args.train_data_file, |
@@ -612,7 +639,7 @@ def main(): | |||
612 | train_set_pad=args.train_set_pad, | 639 | train_set_pad=args.train_set_pad, |
613 | valid_set_pad=args.valid_set_pad, | 640 | valid_set_pad=args.valid_set_pad, |
614 | seed=args.seed, | 641 | seed=args.seed, |
615 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), | 642 | filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections), |
616 | dtype=weight_dtype | 643 | dtype=weight_dtype |
617 | ) | 644 | ) |
618 | datamodule.setup() | 645 | datamodule.setup() |
@@ -647,36 +674,24 @@ def main(): | |||
647 | val_dataloader=datamodule.val_dataloader, | 674 | val_dataloader=datamodule.val_dataloader, |
648 | optimizer=optimizer, | 675 | optimizer=optimizer, |
649 | lr_scheduler=lr_scheduler, | 676 | lr_scheduler=lr_scheduler, |
650 | num_train_epochs=args.num_train_epochs, | ||
651 | sample_frequency=args.sample_frequency, | ||
652 | checkpoint_frequency=args.checkpoint_frequency, | ||
653 | global_step_offset=global_step_offset, | ||
654 | with_prior_preservation=args.num_class_images != 0, | ||
655 | prior_loss_weight=args.prior_loss_weight, | ||
656 | # -- | 677 | # -- |
657 | tokenizer=tokenizer, | ||
658 | sample_scheduler=sample_scheduler, | ||
659 | sample_output_dir=sample_output_dir, | 678 | sample_output_dir=sample_output_dir, |
660 | checkpoint_output_dir=checkpoint_output_dir, | 679 | placeholder_tokens=placeholder_tokens, |
661 | placeholder_tokens=[placeholder_token], | ||
662 | placeholder_token_ids=placeholder_token_ids, | 680 | placeholder_token_ids=placeholder_token_ids, |
663 | learning_rate=args.learning_rate, | ||
664 | gradient_checkpointing=args.gradient_checkpointing, | ||
665 | use_emb_decay=args.use_emb_decay, | ||
666 | emb_decay_target=args.emb_decay_target, | ||
667 | emb_decay_factor=args.emb_decay_factor, | ||
668 | emb_decay_start=args.emb_decay_start, | ||
669 | use_ema=args.use_ema, | ||
670 | ema_inv_gamma=args.ema_inv_gamma, | ||
671 | ema_power=args.ema_power, | ||
672 | ema_max_decay=args.ema_max_decay, | ||
673 | sample_batch_size=args.sample_batch_size, | ||
674 | sample_num_batches=args.sample_batches, | ||
675 | sample_num_steps=args.sample_steps, | ||
676 | sample_image_size=args.sample_image_size, | ||
677 | ) | 681 | ) |
678 | 682 | ||
679 | embeddings.persist() | 683 | if args.simultaneous: |
684 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | ||
685 | else: | ||
686 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | ||
687 | range(len(args.placeholder_tokens)), | ||
688 | args.placeholder_tokens, | ||
689 | args.initializer_tokens, | ||
690 | args.num_vectors, | ||
691 | args.train_data_template | ||
692 | ): | ||
693 | run(i, [placeholder_token], [initializer_token], [num_vectors], data_template) | ||
694 | embeddings.persist() | ||
680 | 695 | ||
681 | 696 | ||
682 | if __name__ == "__main__": | 697 | if __name__ == "__main__": |