Diffstat (limited to 'train_dreambooth.py')
 -rw-r--r--  train_dreambooth.py  33
 1 file changed, 12 insertions, 21 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a1802a0..c180170 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -195,15 +195,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
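The dropped --dataloader_num_workers option mirrored the num_workers argument of PyTorch's DataLoader, and its help text above describes the same semantics. A minimal sketch of that behaviour with a stand-in dataset (not taken from this script):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.zeros(8, 3))                  # stand-in dataset for illustration
    loader = DataLoader(dataset, batch_size=4, num_workers=0)   # 0 = load batches in the main process
    for (batch,) in loader:
        print(batch.shape)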
@@ -577,24 +568,24 @@ def main():
     )

     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)

     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )

-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

     if args.seed is None:
         args.seed = torch.random.seed() >> 32

     set_seed(args.seed)

-    save_args(basepath, args)
+    save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
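The renamed output_dir is a per-run directory under args.output_dir, named after the slugified project and a timestamp. A small illustration of the same construction, assuming the script's slugify helper behaves like python-slugify's slugify() (values below are made up):

    import datetime
    from pathlib import Path
    from slugify import slugify

    project = "My Project"
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path("output").joinpath(slugify(project), now)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(output_dir)   # e.g. output/my-project/2023-01-15T10-30-00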
@@ -618,7 +609,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -627,7 +618,9 @@ def main():
     )

     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")

     if args.use_ema:
         ema_unet = EMAModel(
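The new branch reports, for each placeholder token, its assigned id together with how many token ids its initializer text was encoded into. A sketch with purely hypothetical values, only to show the shape of placeholder_token_stats:

    # Hypothetical inputs; real ids come from the tokenizer and add_placeholder_tokens.
    placeholder_tokens = ["<concept>", "<style>"]
    placeholder_token_ids = [49408, 49409]
    initializer_token_ids = [[320, 1125], [2368]]

    initializer_token_id_lens = [len(ids) for ids in initializer_token_ids]
    placeholder_token_stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
    print(placeholder_token_stats)
    # [('<concept>', 49408, 2), ('<style>', 49409, 1)]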
@@ -726,7 +719,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -830,7 +822,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images,
         args.prior_loss_weight,
         args.seed,
     )
@@ -848,7 +839,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -873,7 +865,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e2)

-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -886,7 +878,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=0,
             num_epochs=args.num_train_epochs,
