| author | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
| commit | e2d3a62bce63fcde940395a1c5618c4eb43385a9 | |
| tree | 574f7a794feab13e1cf0ed18522a66d4737b6db3 /train_ti.py | |
| parent | Unified training script structure | |
Cleanup
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 33 |
1 file changed, 12 insertions, 21 deletions
diff --git a/train_ti.py b/train_ti.py
index d2ca7eb..d752927 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -181,15 +181,6 @@ def parse_args():
         help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
     )
     parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
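For context, the removed `--dataloader_num_workers` flag mirrored the standard PyTorch `DataLoader` `num_workers` argument (a later hunk also drops it from the data module call). A minimal sketch of that argument's semantics, using a toy dataset that is not from this repo:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset purely for illustration; the real script builds its own dataset module.
dataset = TensorDataset(torch.arange(16, dtype=torch.float32).unsqueeze(1))

# num_workers=0 loads batches in the main process (the removed flag's default);
# any value > 0 spawns that many worker processes for data loading.
loader = DataLoader(dataset, batch_size=4, num_workers=0)

for (batch,) in loader:
    print(batch.shape)  # torch.Size([4, 1])
```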
@@ -575,24 +566,24 @@ def main():
 
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
 
     set_seed(args.seed)
 
-    save_args(basepath, args)
+    save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
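The `basepath` to `output_dir` rename is purely cosmetic; the run directory is still `<output_dir>/<slugified project>/<timestamp>`. A small sketch of that layout with placeholder argument values, assuming the `python-slugify` package provides the `slugify()` used here:

```python
from datetime import datetime
from pathlib import Path

from slugify import slugify  # assumed: python-slugify, matching the script's slugify() call

# Placeholder values standing in for the parsed CLI arguments.
output_dir_arg = "output"
project = "My Project"

now = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
output_dir = Path(output_dir_arg).joinpath(slugify(project), now)
output_dir.mkdir(parents=True, exist_ok=True)

print(output_dir)  # e.g. output/my-project/2023-01-14T09-25-13
```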
@@ -616,7 +607,7 @@ def main():
     added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -625,7 +616,9 @@ def main():
     )
 
     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
     if args.use_ema:
         ema_embeddings = EMAModel(
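To make the new log line concrete, here is a minimal sketch with made-up token ids; the exact shapes of the lists returned by `add_placeholder_tokens` are an assumption based on this hunk (one list of initializer token ids per placeholder token):

```python
# Hypothetical values for illustration only; real ids come from the tokenizer.
placeholder_tokens = ["<token1>", "<token2>"]
placeholder_token_ids = [49408, 49409]
initializer_token_ids = [[1125], [320, 1125]]  # assumed: one id list per placeholder token

# Mirrors the new code path: report each placeholder token together with its id
# and the number of initializer token ids it was seeded from.
initializer_token_id_lens = [len(ids) for ids in initializer_token_ids]
placeholder_token_stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
# Added 2 new tokens: [('<token1>', 49408, 1), ('<token2>', 49409, 2)]
```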
@@ -708,7 +701,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -807,7 +799,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images != 0,
         args.prior_loss_weight,
         args.seed,
     )
@@ -825,7 +816,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -849,7 +841,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -862,7 +854,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
             num_epochs=args.num_train_epochs,
