diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 221 |
1 files changed, 114 insertions, 107 deletions
diff --git a/train_ti.py b/train_ti.py index 7aecdef..adba8d4 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -51,6 +51,7 @@ def parse_args(): | |||
51 | parser.add_argument( | 51 | parser.add_argument( |
52 | "--train_data_template", | 52 | "--train_data_template", |
53 | type=str, | 53 | type=str, |
54 | nargs='*', | ||
54 | default="template", | 55 | default="template", |
55 | ) | 56 | ) |
56 | parser.add_argument( | 57 | parser.add_argument( |
@@ -468,11 +469,17 @@ def parse_args(): | |||
468 | args.num_vectors = 1 | 469 | args.num_vectors = 1 |
469 | 470 | ||
470 | if isinstance(args.num_vectors, int): | 471 | if isinstance(args.num_vectors, int): |
471 | args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) | 472 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
472 | 473 | ||
473 | if len(args.placeholder_tokens) != len(args.num_vectors): | 474 | if len(args.placeholder_tokens) != len(args.num_vectors): |
474 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 475 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
475 | 476 | ||
477 | if isinstance(args.train_data_template, str): | ||
478 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | ||
479 | |||
480 | if len(args.placeholder_tokens) != len(args.train_data_template): | ||
481 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | ||
482 | |||
476 | if isinstance(args.collection, str): | 483 | if isinstance(args.collection, str): |
477 | args.collection = [args.collection] | 484 | args.collection = [args.collection] |
478 | 485 | ||
@@ -507,6 +514,8 @@ def main(): | |||
507 | 514 | ||
508 | set_seed(args.seed) | 515 | set_seed(args.seed) |
509 | 516 | ||
517 | seed_generator = torch.Generator().manual_seed(args.seed) | ||
518 | |||
510 | save_args(output_dir, args) | 519 | save_args(output_dir, args) |
511 | 520 | ||
512 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 521 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
@@ -531,19 +540,6 @@ def main(): | |||
531 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 540 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
532 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 541 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
533 | 542 | ||
534 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
535 | tokenizer=tokenizer, | ||
536 | embeddings=embeddings, | ||
537 | placeholder_tokens=args.placeholder_tokens, | ||
538 | initializer_tokens=args.initializer_tokens, | ||
539 | num_vectors=args.num_vectors | ||
540 | ) | ||
541 | |||
542 | if len(placeholder_token_ids) != 0: | ||
543 | initializer_token_id_lens = [len(id) for id in initializer_token_ids] | ||
544 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | ||
545 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | ||
546 | |||
547 | if args.scale_lr: | 543 | if args.scale_lr: |
548 | args.learning_rate = ( | 544 | args.learning_rate = ( |
549 | args.learning_rate * args.gradient_accumulation_steps * | 545 | args.learning_rate * args.gradient_accumulation_steps * |
@@ -566,43 +562,6 @@ def main(): | |||
566 | elif args.mixed_precision == "bf16": | 562 | elif args.mixed_precision == "bf16": |
567 | weight_dtype = torch.bfloat16 | 563 | weight_dtype = torch.bfloat16 |
568 | 564 | ||
569 | datamodule = VlpnDataModule( | ||
570 | data_file=args.train_data_file, | ||
571 | batch_size=args.train_batch_size, | ||
572 | tokenizer=tokenizer, | ||
573 | class_subdir=args.class_image_dir, | ||
574 | num_class_images=args.num_class_images, | ||
575 | size=args.resolution, | ||
576 | num_buckets=args.num_buckets, | ||
577 | progressive_buckets=args.progressive_buckets, | ||
578 | bucket_step_size=args.bucket_step_size, | ||
579 | bucket_max_pixels=args.bucket_max_pixels, | ||
580 | dropout=args.tag_dropout, | ||
581 | shuffle=not args.no_tag_shuffle, | ||
582 | template_key=args.train_data_template, | ||
583 | valid_set_size=args.valid_set_size, | ||
584 | train_set_pad=args.train_set_pad, | ||
585 | valid_set_pad=args.valid_set_pad, | ||
586 | seed=args.seed, | ||
587 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | ||
588 | dtype=weight_dtype | ||
589 | ) | ||
590 | datamodule.setup() | ||
591 | |||
592 | if args.num_class_images != 0: | ||
593 | generate_class_images( | ||
594 | accelerator, | ||
595 | text_encoder, | ||
596 | vae, | ||
597 | unet, | ||
598 | tokenizer, | ||
599 | sample_scheduler, | ||
600 | datamodule.train_dataset, | ||
601 | args.sample_batch_size, | ||
602 | args.sample_image_size, | ||
603 | args.sample_steps | ||
604 | ) | ||
605 | |||
606 | trainer = partial( | 565 | trainer = partial( |
607 | train, | 566 | train, |
608 | accelerator=accelerator, | 567 | accelerator=accelerator, |
@@ -615,63 +574,111 @@ def main(): | |||
615 | callbacks_fn=textual_inversion_strategy | 574 | callbacks_fn=textual_inversion_strategy |
616 | ) | 575 | ) |
617 | 576 | ||
618 | optimizer = optimizer_class( | 577 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
619 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 578 | range(len(args.placeholder_tokens)), |
620 | lr=args.learning_rate, | 579 | args.placeholder_tokens, |
621 | betas=(args.adam_beta1, args.adam_beta2), | 580 | args.initializer_tokens, |
622 | weight_decay=args.adam_weight_decay, | 581 | args.num_vectors, |
623 | eps=args.adam_epsilon, | 582 | args.train_data_template |
624 | amsgrad=args.adam_amsgrad, | 583 | ): |
625 | ) | 584 | cur_subdir = output_dir.joinpath(placeholder_token) |
585 | cur_subdir.mkdir(parents=True, exist_ok=True) | ||
586 | |||
587 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
588 | tokenizer=tokenizer, | ||
589 | embeddings=embeddings, | ||
590 | placeholder_tokens=[placeholder_token], | ||
591 | initializer_tokens=[initializer_token], | ||
592 | num_vectors=[num_vectors] | ||
593 | ) | ||
626 | 594 | ||
627 | lr_scheduler = get_scheduler( | 595 | print( |
628 | args.lr_scheduler, | 596 | f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") |
629 | optimizer=optimizer, | 597 | |
630 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 598 | args.seed = seed_generator.seed() |
631 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 599 | |
632 | min_lr=args.lr_min_lr, | 600 | datamodule = VlpnDataModule( |
633 | warmup_func=args.lr_warmup_func, | 601 | data_file=args.train_data_file, |
634 | annealing_func=args.lr_annealing_func, | 602 | batch_size=args.train_batch_size, |
635 | warmup_exp=args.lr_warmup_exp, | 603 | tokenizer=tokenizer, |
636 | annealing_exp=args.lr_annealing_exp, | 604 | class_subdir=args.class_image_dir, |
637 | cycles=args.lr_cycles, | 605 | num_class_images=args.num_class_images, |
638 | train_epochs=args.num_train_epochs, | 606 | size=args.resolution, |
639 | warmup_epochs=args.lr_warmup_epochs, | 607 | num_buckets=args.num_buckets, |
640 | ) | 608 | progressive_buckets=args.progressive_buckets, |
641 | 609 | bucket_step_size=args.bucket_step_size, | |
642 | trainer( | 610 | bucket_max_pixels=args.bucket_max_pixels, |
643 | project="textual_inversion", | 611 | dropout=args.tag_dropout, |
644 | train_dataloader=datamodule.train_dataloader, | 612 | shuffle=not args.no_tag_shuffle, |
645 | val_dataloader=datamodule.val_dataloader, | 613 | template_key=data_template, |
646 | optimizer=optimizer, | 614 | valid_set_size=args.valid_set_size, |
647 | lr_scheduler=lr_scheduler, | 615 | train_set_pad=args.train_set_pad, |
648 | num_train_epochs=args.num_train_epochs, | 616 | valid_set_pad=args.valid_set_pad, |
649 | sample_frequency=args.sample_frequency, | 617 | seed=args.seed, |
650 | checkpoint_frequency=args.checkpoint_frequency, | 618 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), |
651 | global_step_offset=global_step_offset, | 619 | dtype=weight_dtype |
652 | with_prior_preservation=args.num_class_images != 0, | 620 | ) |
653 | prior_loss_weight=args.prior_loss_weight, | 621 | datamodule.setup() |
654 | # -- | 622 | |
655 | tokenizer=tokenizer, | 623 | optimizer = optimizer_class( |
656 | sample_scheduler=sample_scheduler, | 624 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
657 | output_dir=output_dir, | 625 | lr=args.learning_rate, |
658 | placeholder_tokens=args.placeholder_tokens, | 626 | betas=(args.adam_beta1, args.adam_beta2), |
659 | placeholder_token_ids=placeholder_token_ids, | 627 | weight_decay=args.adam_weight_decay, |
660 | learning_rate=args.learning_rate, | 628 | eps=args.adam_epsilon, |
661 | gradient_checkpointing=args.gradient_checkpointing, | 629 | amsgrad=args.adam_amsgrad, |
662 | use_emb_decay=args.use_emb_decay, | 630 | ) |
663 | emb_decay_target=args.emb_decay_target, | 631 | |
664 | emb_decay_factor=args.emb_decay_factor, | 632 | lr_scheduler = get_scheduler( |
665 | emb_decay_start=args.emb_decay_start, | 633 | args.lr_scheduler, |
666 | use_ema=args.use_ema, | 634 | optimizer=optimizer, |
667 | ema_inv_gamma=args.ema_inv_gamma, | 635 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
668 | ema_power=args.ema_power, | 636 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
669 | ema_max_decay=args.ema_max_decay, | 637 | min_lr=args.lr_min_lr, |
670 | sample_batch_size=args.sample_batch_size, | 638 | warmup_func=args.lr_warmup_func, |
671 | sample_num_batches=args.sample_batches, | 639 | annealing_func=args.lr_annealing_func, |
672 | sample_num_steps=args.sample_steps, | 640 | warmup_exp=args.lr_warmup_exp, |
673 | sample_image_size=args.sample_image_size, | 641 | annealing_exp=args.lr_annealing_exp, |
674 | ) | 642 | cycles=args.lr_cycles, |
643 | train_epochs=args.num_train_epochs, | ||
644 | warmup_epochs=args.lr_warmup_epochs, | ||
645 | ) | ||
646 | |||
647 | trainer( | ||
648 | project="textual_inversion", | ||
649 | train_dataloader=datamodule.train_dataloader, | ||
650 | val_dataloader=datamodule.val_dataloader, | ||
651 | optimizer=optimizer, | ||
652 | lr_scheduler=lr_scheduler, | ||
653 | num_train_epochs=args.num_train_epochs, | ||
654 | sample_frequency=args.sample_frequency, | ||
655 | checkpoint_frequency=args.checkpoint_frequency, | ||
656 | global_step_offset=global_step_offset, | ||
657 | with_prior_preservation=args.num_class_images != 0, | ||
658 | prior_loss_weight=args.prior_loss_weight, | ||
659 | # -- | ||
660 | tokenizer=tokenizer, | ||
661 | sample_scheduler=sample_scheduler, | ||
662 | output_dir=cur_subdir, | ||
663 | placeholder_tokens=[placeholder_token], | ||
664 | placeholder_token_ids=placeholder_token_ids, | ||
665 | learning_rate=args.learning_rate, | ||
666 | gradient_checkpointing=args.gradient_checkpointing, | ||
667 | use_emb_decay=args.use_emb_decay, | ||
668 | emb_decay_target=args.emb_decay_target, | ||
669 | emb_decay_factor=args.emb_decay_factor, | ||
670 | emb_decay_start=args.emb_decay_start, | ||
671 | use_ema=args.use_ema, | ||
672 | ema_inv_gamma=args.ema_inv_gamma, | ||
673 | ema_power=args.ema_power, | ||
674 | ema_max_decay=args.ema_max_decay, | ||
675 | sample_batch_size=args.sample_batch_size, | ||
676 | sample_num_batches=args.sample_batches, | ||
677 | sample_num_steps=args.sample_steps, | ||
678 | sample_image_size=args.sample_image_size, | ||
679 | ) | ||
680 | |||
681 | embeddings.persist() | ||
675 | 682 | ||
676 | 683 | ||
677 | if __name__ == "__main__": | 684 | if __name__ == "__main__": |