diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
commit | e2d3a62bce63fcde940395a1c5618c4eb43385a9 (patch) | |
tree | 574f7a794feab13e1cf0ed18522a66d4737b6db3 /train_ti.py | |
parent | Unified training script structure (diff) | |
download | textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.gz textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.bz2 textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.zip |
Cleanup
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 33 |
1 file changed, 12 insertions, 21 deletions
diff --git a/train_ti.py b/train_ti.py index d2ca7eb..d752927 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -181,15 +181,6 @@ def parse_args(): | |||
181 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', | 181 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', |
182 | ) | 182 | ) |
183 | parser.add_argument( | 183 | parser.add_argument( |
184 | "--dataloader_num_workers", | ||
185 | type=int, | ||
186 | default=0, | ||
187 | help=( | ||
188 | "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" | ||
189 | " process." | ||
190 | ), | ||
191 | ) | ||
192 | parser.add_argument( | ||
193 | "--num_train_epochs", | 184 | "--num_train_epochs", |
194 | type=int, | 185 | type=int, |
195 | default=100 | 186 | default=100 |
@@ -575,24 +566,24 @@ def main(): | |||
575 | 566 | ||
576 | global_step_offset = args.global_step | 567 | global_step_offset = args.global_step |
577 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 568 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
578 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) | 569 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) |
579 | basepath.mkdir(parents=True, exist_ok=True) | 570 | output_dir.mkdir(parents=True, exist_ok=True) |
580 | 571 | ||
581 | accelerator = Accelerator( | 572 | accelerator = Accelerator( |
582 | log_with=LoggerType.TENSORBOARD, | 573 | log_with=LoggerType.TENSORBOARD, |
583 | logging_dir=f"{basepath}", | 574 | logging_dir=f"{output_dir}", |
584 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 575 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
585 | mixed_precision=args.mixed_precision | 576 | mixed_precision=args.mixed_precision |
586 | ) | 577 | ) |
587 | 578 | ||
588 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 579 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
589 | 580 | ||
590 | if args.seed is None: | 581 | if args.seed is None: |
591 | args.seed = torch.random.seed() >> 32 | 582 | args.seed = torch.random.seed() >> 32 |
592 | 583 | ||
593 | set_seed(args.seed) | 584 | set_seed(args.seed) |
594 | 585 | ||
595 | save_args(basepath, args) | 586 | save_args(output_dir, args) |
596 | 587 | ||
597 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 588 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
598 | args.pretrained_model_name_or_path) | 589 | args.pretrained_model_name_or_path) |
@@ -616,7 +607,7 @@ def main(): | |||
616 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 607 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
617 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 608 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
618 | 609 | ||
619 | placeholder_token_ids = add_placeholder_tokens( | 610 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
620 | tokenizer=tokenizer, | 611 | tokenizer=tokenizer, |
621 | embeddings=embeddings, | 612 | embeddings=embeddings, |
622 | placeholder_tokens=args.placeholder_tokens, | 613 | placeholder_tokens=args.placeholder_tokens, |
@@ -625,7 +616,9 @@ def main(): | |||
625 | ) | 616 | ) |
626 | 617 | ||
627 | if len(placeholder_token_ids) != 0: | 618 | if len(placeholder_token_ids) != 0: |
628 | print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") | 619 | initializer_token_id_lens = [len(id) for id in initializer_token_ids] |
620 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | ||
621 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | ||
629 | 622 | ||
630 | if args.use_ema: | 623 | if args.use_ema: |
631 | ema_embeddings = EMAModel( | 624 | ema_embeddings = EMAModel( |
@@ -708,7 +701,6 @@ def main(): | |||
708 | template_key=args.train_data_template, | 701 | template_key=args.train_data_template, |
709 | valid_set_size=args.valid_set_size, | 702 | valid_set_size=args.valid_set_size, |
710 | valid_set_repeat=args.valid_set_repeat, | 703 | valid_set_repeat=args.valid_set_repeat, |
711 | num_workers=args.dataloader_num_workers, | ||
712 | seed=args.seed, | 704 | seed=args.seed, |
713 | filter=keyword_filter, | 705 | filter=keyword_filter, |
714 | dtype=weight_dtype | 706 | dtype=weight_dtype |
@@ -807,7 +799,6 @@ def main(): | |||
807 | noise_scheduler, | 799 | noise_scheduler, |
808 | unet, | 800 | unet, |
809 | text_encoder, | 801 | text_encoder, |
810 | args.num_class_images != 0, | ||
811 | args.prior_loss_weight, | 802 | args.prior_loss_weight, |
812 | args.seed, | 803 | args.seed, |
813 | ) | 804 | ) |
@@ -825,7 +816,8 @@ def main(): | |||
825 | scheduler=sample_scheduler, | 816 | scheduler=sample_scheduler, |
826 | placeholder_tokens=args.placeholder_tokens, | 817 | placeholder_tokens=args.placeholder_tokens, |
827 | placeholder_token_ids=placeholder_token_ids, | 818 | placeholder_token_ids=placeholder_token_ids, |
828 | output_dir=basepath, | 819 | output_dir=output_dir, |
820 | sample_steps=args.sample_steps, | ||
829 | sample_image_size=args.sample_image_size, | 821 | sample_image_size=args.sample_image_size, |
830 | sample_batch_size=args.sample_batch_size, | 822 | sample_batch_size=args.sample_batch_size, |
831 | sample_batches=args.sample_batches, | 823 | sample_batches=args.sample_batches, |
@@ -849,7 +841,7 @@ def main(): | |||
849 | ) | 841 | ) |
850 | lr_finder.run(num_epochs=100, end_lr=1e3) | 842 | lr_finder.run(num_epochs=100, end_lr=1e3) |
851 | 843 | ||
852 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 844 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) |
853 | plt.close() | 845 | plt.close() |
854 | else: | 846 | else: |
855 | train_loop( | 847 | train_loop( |
@@ -862,7 +854,6 @@ def main(): | |||
862 | val_dataloader=val_dataloader, | 854 | val_dataloader=val_dataloader, |
863 | loss_step=loss_step_, | 855 | loss_step=loss_step_, |
864 | sample_frequency=args.sample_frequency, | 856 | sample_frequency=args.sample_frequency, |
865 | sample_steps=args.sample_steps, | ||
866 | checkpoint_frequency=args.checkpoint_frequency, | 857 | checkpoint_frequency=args.checkpoint_frequency, |
867 | global_step_offset=global_step_offset, | 858 | global_step_offset=global_step_offset, |
868 | num_epochs=args.num_train_epochs, | 859 | num_epochs=args.num_train_epochs, |