author     Volpeon <git@volpeon.ink>   2023-01-17 07:20:45 +0100
committer  Volpeon <git@volpeon.ink>   2023-01-17 07:20:45 +0100
commit     5821523a524190490a287c5e2aacb6e72cc3a4cf (patch)
tree       c0eac536c754f078683be6d59893ad23d70baf51 /train_ti.py
parent     Training update (diff)
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  113
1 file changed, 64 insertions(+), 49 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index e7aeb23..0891c49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -14,7 +14,7 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -79,6 +79,10 @@ def parse_args():
         help="Number of vectors per embedding."
     )
     parser.add_argument(
+        "--simultaneous",
+        action="store_true",
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=0,
@@ -474,11 +478,12 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if isinstance(args.train_data_template, str):
-        args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
+    if not args.simultaneous:
+        if isinstance(args.train_data_template, str):
+            args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
-    if len(args.placeholder_tokens) != len(args.train_data_template):
-        raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+        if len(args.placeholder_tokens) != len(args.train_data_template):
+            raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -560,6 +565,8 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -569,30 +576,50 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
-        callbacks_fn=textual_inversion_strategy
-    )
-
-    checkpoint_output_dir = output_dir.joinpath("checkpoints")
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.train_data_template
-    ):
-        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}")
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        strategy=textual_inversion_strategy,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        checkpoint_output_dir=checkpoint_output_dir,
+        learning_rate=args.learning_rate,
+        gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay_factor=args.emb_decay_factor,
+        emb_decay_start=args.emb_decay_start,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
+    )
+
+    def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+        if len(placeholder_tokens) == 1:
+            sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token[0]}")
+        else:
+            sample_output_dir = output_dir.joinpath("samples")
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
+            placeholder_tokens=placeholder_tokens,
+            initializer_tokens=initializer_tokens,
+            num_vectors=num_vectors
         )
 
-        print(
-            f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
+        stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
+
+        print(f"{i + 1}: {stats})")
 
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
@@ -612,7 +639,7 @@ def main():
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
             seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
+            filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype
         )
         datamodule.setup()
@@ -647,36 +674,24 @@ def main():
             val_dataloader=datamodule.val_dataloader,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            num_train_epochs=args.num_train_epochs,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            with_prior_preservation=args.num_class_images != 0,
-            prior_loss_weight=args.prior_loss_weight,
             # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
             sample_output_dir=sample_output_dir,
-            checkpoint_output_dir=checkpoint_output_dir,
-            placeholder_tokens=[placeholder_token],
+            placeholder_tokens=placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=args.use_emb_decay,
-            emb_decay_target=args.emb_decay_target,
-            emb_decay_factor=args.emb_decay_factor,
-            emb_decay_start=args.emb_decay_start,
-            use_ema=args.use_ema,
-            ema_inv_gamma=args.ema_inv_gamma,
-            ema_power=args.ema_power,
-            ema_max_decay=args.ema_max_decay,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
         )
 
-        embeddings.persist()
+    if args.simultaneous:
+        run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
+    else:
+        for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
+            range(len(args.placeholder_tokens)),
+            args.placeholder_tokens,
+            args.initializer_tokens,
+            args.num_vectors,
+            args.train_data_template
+        ):
+            run(i, [placeholder_token], [initializer_token], [num_vectors], data_template)
+            embeddings.persist()
 
 
 if __name__ == "__main__":
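
For context, a minimal self-contained sketch of the dispatch pattern this commit introduces. This is not code from the repository: run() is a stand-in for the training call defined in main(), and all token names and values below are made-up examples. With --simultaneous set, every placeholder token shares one training run; otherwise each token gets its own run with its own data template, as before.

# Hypothetical illustration of the --simultaneous dispatch; only the control flow mirrors the diff.
from argparse import Namespace

def run(i, placeholder_tokens, initializer_tokens, num_vectors, data_template):
    # Stand-in for the per-run training logic defined in train_ti.py's main().
    print(f"run {i}: {placeholder_tokens} ({initializer_tokens}), vectors={num_vectors}, template={data_template}")

args = Namespace(
    simultaneous=False,  # flip to True to train all tokens in a single run
    placeholder_tokens=["<tokenA>", "<tokenB>"],
    initializer_tokens=["person", "style"],
    num_vectors=[2, 4],
    train_data_template=["template_a", "template_b"],
)

if args.simultaneous:
    # One combined run over every placeholder token.
    run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
else:
    # One run per token, each with its own data template (the pre-existing behaviour).
    for i, (token, init, vectors, template) in enumerate(zip(
        args.placeholder_tokens, args.initializer_tokens,
        args.num_vectors, args.train_data_template,
    )):
        run(i, [token], [init], [vectors], template)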