diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 09:25:13 +0100 |
commit | e2d3a62bce63fcde940395a1c5618c4eb43385a9 (patch) | |
tree | 574f7a794feab13e1cf0ed18522a66d4737b6db3 /train_dreambooth.py | |
parent | Unified training script structure (diff) | |
download | textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.gz textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.tar.bz2 textual-inversion-diff-e2d3a62bce63fcde940395a1c5618c4eb43385a9.zip |
Cleanup
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 33 |
1 files changed, 12 insertions, 21 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index a1802a0..c180170 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -195,15 +195,6 @@ def parse_args(): | |||
195 | ), | 195 | ), |
196 | ) | 196 | ) |
197 | parser.add_argument( | 197 | parser.add_argument( |
198 | "--dataloader_num_workers", | ||
199 | type=int, | ||
200 | default=0, | ||
201 | help=( | ||
202 | "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" | ||
203 | " process." | ||
204 | ), | ||
205 | ) | ||
206 | parser.add_argument( | ||
207 | "--num_train_epochs", | 198 | "--num_train_epochs", |
208 | type=int, | 199 | type=int, |
209 | default=100 | 200 | default=100 |
@@ -577,24 +568,24 @@ def main(): | |||
577 | ) | 568 | ) |
578 | 569 | ||
579 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 570 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
580 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) | 571 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) |
581 | basepath.mkdir(parents=True, exist_ok=True) | 572 | output_dir.mkdir(parents=True, exist_ok=True) |
582 | 573 | ||
583 | accelerator = Accelerator( | 574 | accelerator = Accelerator( |
584 | log_with=LoggerType.TENSORBOARD, | 575 | log_with=LoggerType.TENSORBOARD, |
585 | logging_dir=f"{basepath}", | 576 | logging_dir=f"{output_dir}", |
586 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 577 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
587 | mixed_precision=args.mixed_precision | 578 | mixed_precision=args.mixed_precision |
588 | ) | 579 | ) |
589 | 580 | ||
590 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 581 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
591 | 582 | ||
592 | if args.seed is None: | 583 | if args.seed is None: |
593 | args.seed = torch.random.seed() >> 32 | 584 | args.seed = torch.random.seed() >> 32 |
594 | 585 | ||
595 | set_seed(args.seed) | 586 | set_seed(args.seed) |
596 | 587 | ||
597 | save_args(basepath, args) | 588 | save_args(output_dir, args) |
598 | 589 | ||
599 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 590 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
600 | args.pretrained_model_name_or_path) | 591 | args.pretrained_model_name_or_path) |
@@ -618,7 +609,7 @@ def main(): | |||
618 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 609 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
619 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 610 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
620 | 611 | ||
621 | placeholder_token_ids = add_placeholder_tokens( | 612 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
622 | tokenizer=tokenizer, | 613 | tokenizer=tokenizer, |
623 | embeddings=embeddings, | 614 | embeddings=embeddings, |
624 | placeholder_tokens=args.placeholder_tokens, | 615 | placeholder_tokens=args.placeholder_tokens, |
@@ -627,7 +618,9 @@ def main(): | |||
627 | ) | 618 | ) |
628 | 619 | ||
629 | if len(placeholder_token_ids) != 0: | 620 | if len(placeholder_token_ids) != 0: |
630 | print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") | 621 | initializer_token_id_lens = [len(id) for id in initializer_token_ids] |
622 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | ||
623 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | ||
631 | 624 | ||
632 | if args.use_ema: | 625 | if args.use_ema: |
633 | ema_unet = EMAModel( | 626 | ema_unet = EMAModel( |
@@ -726,7 +719,6 @@ def main(): | |||
726 | template_key=args.train_data_template, | 719 | template_key=args.train_data_template, |
727 | valid_set_size=args.valid_set_size, | 720 | valid_set_size=args.valid_set_size, |
728 | valid_set_repeat=args.valid_set_repeat, | 721 | valid_set_repeat=args.valid_set_repeat, |
729 | num_workers=args.dataloader_num_workers, | ||
730 | seed=args.seed, | 722 | seed=args.seed, |
731 | filter=keyword_filter, | 723 | filter=keyword_filter, |
732 | dtype=weight_dtype | 724 | dtype=weight_dtype |
@@ -830,7 +822,6 @@ def main(): | |||
830 | noise_scheduler, | 822 | noise_scheduler, |
831 | unet, | 823 | unet, |
832 | text_encoder, | 824 | text_encoder, |
833 | args.num_class_images, | ||
834 | args.prior_loss_weight, | 825 | args.prior_loss_weight, |
835 | args.seed, | 826 | args.seed, |
836 | ) | 827 | ) |
@@ -848,7 +839,8 @@ def main(): | |||
848 | scheduler=sample_scheduler, | 839 | scheduler=sample_scheduler, |
849 | placeholder_tokens=args.placeholder_tokens, | 840 | placeholder_tokens=args.placeholder_tokens, |
850 | placeholder_token_ids=placeholder_token_ids, | 841 | placeholder_token_ids=placeholder_token_ids, |
851 | output_dir=basepath, | 842 | output_dir=output_dir, |
843 | sample_steps=args.sample_steps, | ||
852 | sample_image_size=args.sample_image_size, | 844 | sample_image_size=args.sample_image_size, |
853 | sample_batch_size=args.sample_batch_size, | 845 | sample_batch_size=args.sample_batch_size, |
854 | sample_batches=args.sample_batches, | 846 | sample_batches=args.sample_batches, |
@@ -873,7 +865,7 @@ def main(): | |||
873 | ) | 865 | ) |
874 | lr_finder.run(num_epochs=100, end_lr=1e2) | 866 | lr_finder.run(num_epochs=100, end_lr=1e2) |
875 | 867 | ||
876 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 868 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) |
877 | plt.close() | 869 | plt.close() |
878 | else: | 870 | else: |
879 | train_loop( | 871 | train_loop( |
@@ -886,7 +878,6 @@ def main(): | |||
886 | val_dataloader=val_dataloader, | 878 | val_dataloader=val_dataloader, |
887 | loss_step=loss_step_, | 879 | loss_step=loss_step_, |
888 | sample_frequency=args.sample_frequency, | 880 | sample_frequency=args.sample_frequency, |
889 | sample_steps=args.sample_steps, | ||
890 | checkpoint_frequency=args.checkpoint_frequency, | 881 | checkpoint_frequency=args.checkpoint_frequency, |
891 | global_step_offset=0, | 882 | global_step_offset=0, |
892 | num_epochs=args.num_train_epochs, | 883 | num_epochs=args.num_train_epochs, |