Diffstat (limited to 'train_ti.py')
 train_ti.py | 33 ++++++++++++---------------------
 1 file changed, 12 insertions(+), 21 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index d2ca7eb..d752927 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -181,15 +181,6 @@ def parse_args():
         help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
     )
     parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -575,24 +566,24 @@ def main():
 
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
 
     set_seed(args.seed)
 
-    save_args(basepath, args)
+    save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
@@ -616,7 +607,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -625,7 +616,9 @@ def main():
     )
 
     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
     if args.use_ema:
         ema_embeddings = EMAModel(
@@ -708,7 +701,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -807,7 +799,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images != 0,
         args.prior_loss_weight,
         args.seed,
     )
@@ -825,7 +816,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -849,7 +841,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -862,7 +854,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
             num_epochs=args.num_train_epochs,
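
Note: the hunks around add_placeholder_tokens change its call site to unpack a second return value, initializer_token_ids, so the startup log can also report how many tokenizer ids each initializer resolved to. A minimal sketch of what the new log line produces, using hypothetical token names and ids for illustration (the function itself lives elsewhere in the repo and is not part of this diff):

    # Hypothetical values; real ids come from the tokenizer.
    placeholder_tokens = ["<sks>"]
    placeholder_token_ids = [[49408, 49409]]   # ids minted for the placeholder's vectors
    initializer_token_ids = [[320, 1125]]      # ids the initializer text tokenized to

    initializer_token_id_lens = [len(ids) for ids in initializer_token_ids]
    stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
    print(f"Added {len(placeholder_token_ids)} new tokens: {stats}")
    # -> Added 1 new tokens: [('<sks>', [49408, 49409], 2)]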
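
Note on the unchanged seed handling visible in the second hunk: torch.random.seed() returns a fresh 64-bit seed, while set_seed also feeds the value to NumPy's RNG, which only accepts seeds below 2**32, so the right shift keeps just the upper 32 bits. A quick check of that invariant:

    import torch

    seed64 = torch.random.seed()  # non-deterministic 64-bit seed
    seed32 = seed64 >> 32         # upper 32 bits only
    assert 0 <= seed32 < 2**32    # in range for numpy.random.seed, hence set_seed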