Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  33
1 file changed, 12 insertions(+), 21 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a1802a0..c180170 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -195,15 +195,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -577,24 +568,24 @@ def main():
     )

     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)

     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )

-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

     if args.seed is None:
         args.seed = torch.random.seed() >> 32

     set_seed(args.seed)

-    save_args(basepath, args)
+    save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
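An aside on the seed handling this hunk keeps in place: torch.random.seed() reseeds torch's default generator from a non-deterministic source and returns the 64-bit seed it used, and the >> 32 keeps only the upper 32 bits, presumably so the value fits the 32-bit range that accelerate's set_seed (which also seeds NumPy) accepts. A minimal, self-contained sketch of that behavior:

    import torch

    # torch.random.seed() picks a non-deterministic seed, applies it to
    # torch's default RNG, and returns it as a 64-bit integer.
    raw_seed = torch.random.seed()

    # Keep the upper 32 bits so the value stays below 2**32, the bound
    # NumPy's legacy seeding (used by accelerate.utils.set_seed) enforces.
    seed = raw_seed >> 32
    assert 0 <= seed < 2**32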
@@ -618,7 +609,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -627,7 +618,9 @@ def main():
     )

     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")

     if args.use_ema:
         ema_unet = EMAModel(
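For illustration, the reworked logging in this hunk zips each placeholder token with its newly assigned id and with the number of ids its initializer token expanded to. A toy reproduction with made-up tokens and ids (the real values come from the tokenizer; note the comprehension shadows the id builtin, as in the patch):

    placeholder_tokens = ["<subject>", "<style>"]
    placeholder_token_ids = [49408, 49409]
    initializer_token_ids = [[1125, 2368], [4567]]  # initializers may tokenize to several ids

    initializer_token_id_lens = [len(id) for id in initializer_token_ids]
    placeholder_token_stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
    print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
    # Added 2 new tokens: [('<subject>', 49408, 2), ('<style>', 49409, 1)]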
@@ -726,7 +719,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -830,7 +822,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images,
         args.prior_loss_weight,
         args.seed,
     )
@@ -848,7 +839,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -873,7 +865,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e2)

-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -886,7 +878,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=0,
             num_epochs=args.num_train_epochs,