author     Volpeon <git@volpeon.ink>   2023-01-16 17:09:01 +0100
committer  Volpeon <git@volpeon.ink>   2023-01-16 17:09:01 +0100
commit     36440e48ce279872d6e736bcb1bf57d13da73a11 (patch)
tree       8ba9593d8a887517c70b01932c137c9c3f759e8f /train_dreambooth.py
parent     More training adjustments (diff)
download   textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.tar.gz
           textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.tar.bz2
           textual-inversion-diff-36440e48ce279872d6e736bcb1bf57d13da73a11.zip
Moved multi-TI code from Dreambooth to TI script
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--   train_dreambooth.py   135
1 file changed, 2 insertions(+), 133 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1dc41b1..6511f9b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -200,23 +200,6 @@ def parse_args():
         default=100
     )
     parser.add_argument(
-        "--ti_data_template",
-        type=str,
-        nargs='*',
-        default=[],
-    )
-    parser.add_argument(
-        "--ti_num_train_epochs",
-        type=int,
-        default=10
-    )
-    parser.add_argument(
-        "--ti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
         "--max_train_steps",
         type=int,
         default=None,
@@ -245,12 +228,6 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--ti_learning_rate",
-        type=float,
-        default=1e-2,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -482,12 +459,6 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if isinstance(args.ti_data_template, str):
-        args.ti_data_template = [args.ti_data_template]
-
-    if len(args.ti_data_template) == 0:
-        raise ValueError("You must specify --ti_data_template")
-
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -521,8 +492,6 @@ def main():
 
     set_seed(args.seed)
 
-    seed_generator = torch.Generator().manual_seed(args.seed)
-
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -583,107 +552,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )
 
-    # Initial TI
-
-    print("Phase 1: Textual Inversion")
-
-    cur_dir = output_dir.joinpath("1-ti")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.ti_data_template
-    ):
-        cur_subdir = cur_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
-
-        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-            tokenizer=tokenizer,
-            embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
-        )
-
-        print(
-            f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
-
-        args.seed = seed_generator.seed()
-
-        datamodule = VlpnDataModule(
-            data_file=args.train_data_file,
-            batch_size=args.ti_batch_size,
-            tokenizer=tokenizer,
-            class_subdir=args.class_image_dir,
-            num_class_images=args.num_class_images,
-            size=args.resolution,
-            shuffle=not args.no_tag_shuffle,
-            template_key=data_template,
-            valid_set_size=1,
-            train_set_pad=args.train_set_pad,
-            valid_set_pad=args.valid_set_pad,
-            seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
-            dtype=weight_dtype
-        )
-        datamodule.setup()
-
-        optimizer = optimizer_class(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            lr=args.ti_learning_rate,
-            betas=(args.adam_beta1, args.adam_beta2),
-            weight_decay=0.0,
-            eps=args.adam_epsilon,
-        )
-
-        lr_scheduler = get_scheduler(
-            "one_cycle",
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(datamodule.train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            train_epochs=args.ti_num_train_epochs,
-        )
-
-        trainer(
-            callbacks_fn=textual_inversion_strategy,
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            seed=args.seed,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=args.ti_num_train_epochs,
-            sample_frequency=args.ti_num_train_epochs // 5,
-            checkpoint_frequency=9999999,
-            # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
-            placeholder_tokens=[placeholder_token],
-            placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.ti_learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=True,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
-        )
-
-        embeddings.persist()
-
-    # Dreambooth
-
-    print("Phase 2: Dreambooth")
-
-    cur_dir = output_dir.joinpath("2-db")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    args.seed = seed_generator.seed()
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -746,12 +614,13 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        prepare_unet=True,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=cur_dir,
+        output_dir=output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,