author     Volpeon <git@volpeon.ink>  2023-01-15 10:12:04 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-15 10:12:04 +0100
commit     34648b763fa60e3161a5b5f1243ed1b5c3b0188e (patch)
tree       4c2b8104a8d1af26955561959591249d9281a02f /train_ti.py
parent     Added functional trainer (diff)
Added functional TI strategy
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 108
1 file changed, 30 insertions(+), 78 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 97e4e72..2fd325b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,8 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
 from trainer_old.base import Checkpointer
-from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
@@ -387,6 +388,11 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
         "--emb_decay_target",
         default=0.4,
         type=float,
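
Note: the new --use_emb_decay switch gates the LR-scaled embedding decay that the on_after_optimize callback removed further down used to apply inline. A minimal sketch of that weighting, assuming the strategy keeps the same expression (the helper name is hypothetical):

def emb_decay_weight(lr: float, base_lr: float, factor: float, start: float) -> float:
    # Scale the decay factor by how far the current LR sits between `start`
    # (--emb_decay_start) and the base learning rate, clamped to [0.0, 1.0].
    # This is the expression from the removed on_after_optimize callback.
    return min(1.0, max(0.0, factor * ((lr - start) / (base_lr - start))))

At the base learning rate the full factor applies; as the scheduler anneals the LR toward --emb_decay_start, the decay weight falls to zero.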
@@ -591,14 +597,6 @@ def main():
     else:
         ema_embeddings = None

-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-
-    text_encoder.text_model.encoder.requires_grad_(False)
-    text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
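
Note: the explicit freezing dropped here is presumably taken over by the strategy's prepare step (or by get_models); the deleted calls amounted to freezing everything except the temporary token embedding that textual inversion optimizes. A sketch restating that intent (the function name is hypothetical; the individual calls are the ones removed in this diff, including the old on_prepare callback below):

def freeze_all_but_placeholder_embedding(vae, unet, text_encoder):
    # Freeze the VAE, the UNet and every text-encoder submodule ...
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
    # ... leaving only the temporary token embedding trainable.
    text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)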
@@ -719,73 +717,36 @@ def main():
         seed=args.seed,
     )

-    def on_prepare():
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
-
-        if args.gradient_checkpointing:
-            unet.train()
-
-    @contextmanager
-    def on_train(epoch: int):
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
-
-    @contextmanager
-    def on_eval():
-        try:
-            tokenizer.eval()
-
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
-
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        if args.emb_decay_factor != 0:
-            text_encoder.text_model.embeddings.normalize(
-                args.emb_decay_target,
-                min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
-            )
-
-        if args.use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    def on_log():
-        if args.use_ema:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
-
-    checkpointer = TextualInversionCheckpointer(
-        dtype=weight_dtype,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
+    strategy = textual_inversion_strategy(
         accelerator=accelerator,
-        vae=vae,
         unet=unet,
-        tokenizer=tokenizer,
         text_encoder=text_encoder,
-        ema_embeddings=ema_embeddings,
+        tokenizer=tokenizer,
+        vae=vae,
         sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        output_dir=output_dir,
+        seed=args.seed,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=output_dir,
-        sample_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
+        learning_rate=args.learning_rate,
+        gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay_factor=args.emb_decay_factor,
+        emb_decay_start=args.emb_decay_start,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )

-    if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
-
     if args.find_lr:
         lr_finder = LRFinder(
             accelerator=accelerator,
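
Note: the two hunks below show how the return value is consumed. Instead of wiring loss_step and the individual on_* callbacks by hand, the bundle returned by textual_inversion_strategy is unpacked with **strategy into both LRFinder and the trainer call, so it has to behave like a mapping from those keyword names to callables. A hypothetical sketch of that shape, assuming the keys mirror the keyword arguments they replace (the real construction lives in training/strategy/ti.py):

from contextlib import nullcontext

# Assumed shape only; the exact keys and signatures are defined by the strategy module.
strategy = {
    "on_prepare": lambda: None,                  # mark temp_token_embedding trainable
    "on_train": lambda epoch: nullcontext(),     # tokenizer.train() for the epoch
    "on_eval": lambda: nullcontext(),            # tokenizer.eval() plus EMA weights
    "on_after_optimize": lambda lr: None,        # embedding decay and EMA step
    "on_log": lambda: {},                        # extra values for the progress log
    "on_sample": lambda *a, **kw: None,          # sample-image generation
    "on_checkpoint": lambda *a, **kw: None,      # save the learned embeddings
}

Unpacking this with train(..., **strategy) or LRFinder(..., **strategy) is then equivalent to the explicit callback wiring the old code did.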
@@ -793,10 +754,7 @@ def main():
             model=text_encoder,
             train_dataloader=train_dataloader,
             val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            on_train=on_train,
-            on_eval=on_eval,
-            on_after_optimize=on_after_optimize,
+            **strategy,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)

@@ -811,13 +769,7 @@ def main():
         checkpoint_frequency=args.checkpoint_frequency,
         global_step_offset=global_step_offset,
         prior_loss_weight=args.prior_loss_weight,
-        on_prepare=on_prepare,
-        on_log=on_log,
-        on_train=on_train,
-        on_after_optimize=on_after_optimize,
-        on_eval=on_eval,
-        on_sample=checkpointer.save_samples,
-        on_checkpoint=checkpointer.checkpoint,
+        **strategy,
     )

