path: root/train_ti.py
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  27
1 file changed, 15 insertions, 12 deletions
diff --git a/train_ti.py b/train_ti.py
index 928b721..8631892 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -397,7 +397,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_factor",
-        default=1,
+        default=0,
         type=float,
         help="Embedding decay factor."
     )
@@ -532,13 +532,17 @@ class Checkpointer(CheckpointerBase):
 
     @torch.no_grad()
     def save_samples(self, step):
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         ema_context = self.ema_embeddings.apply_temporary(
             text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
-            orig_dtype = text_encoder.dtype
+            orig_unet_dtype = unet.dtype
+            orig_text_encoder_dtype = text_encoder.dtype
+
+            unet.to(dtype=self.weight_dtype)
             text_encoder.to(dtype=self.weight_dtype)
 
             pipeline = VlpnStableDiffusion(
@@ -552,7 +556,8 @@ class Checkpointer(CheckpointerBase):
 
             super().save_samples(pipeline, step)
 
-            text_encoder.to(dtype=orig_dtype)
+            unet.to(dtype=orig_unet_dtype)
+            text_encoder.to(dtype=orig_text_encoder_dtype)
 
             del text_encoder
             del pipeline
@@ -742,20 +747,17 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
     vae.to(accelerator.device, dtype=weight_dtype)
-    unet.to(accelerator.device, dtype=weight_dtype)
 
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
     if args.gradient_checkpointing:
         unet.train()
-    else:
-        unet.eval()
 
     @contextmanager
     def on_train(epoch: int):
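In this hunk the unet is handed to accelerator.prepare alongside the text encoder, optimizer, data loaders, and scheduler, so Accelerate takes care of its device placement and wrapping and the manual unet.to(accelerator.device, ...) call disappears; unet.train() is then only needed when gradient checkpointing is active. A self-contained sketch of that setup with toy stand-in modules (the model shapes and the gradient_checkpointing flag are made up for illustration, and it assumes the accelerate package is installed):

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()

# Toy stand-ins for the real models and data loader.
unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(text_encoder.parameters(), lr=1e-4)
train_dataloader = DataLoader(TensorDataset(torch.zeros(8, 4)), batch_size=2)

# prepare() moves every object to the right device and wraps it for
# distributed / mixed-precision execution, so no manual unet.to(...) is needed.
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
    unet, text_encoder, optimizer, train_dataloader
)

gradient_checkpointing = True
if gradient_checkpointing:
    # Gradient checkpointing re-runs forward passes during backward,
    # which requires the module to be in training mode.
    unet.train()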
@@ -780,10 +782,11 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(
-            args.emb_decay_target,
-            min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
-        )
+        if args.emb_decay_factor != 0:
+            text_encoder.text_model.embeddings.normalize(
+                args.emb_decay_target,
+                min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
+            )
 
     if args.use_ema:
         ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
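With --emb_decay_factor now defaulting to 0 (first hunk), the normalize call is skipped entirely unless a non-zero factor is given; otherwise the decay weight scales with where the current learning rate sits between emb_decay_start and the base learning rate, clamped to [0, 1]. A standalone sketch of that weight computation; the function name is made up and the example values are illustrative only:

def emb_decay_weight(lr: float, learning_rate: float,
                     emb_decay_start: float, emb_decay_factor: float) -> float:
    """Decay weight for embedding normalization, clamped to [0, 1]."""
    if emb_decay_factor == 0:
        return 0.0  # decay disabled, normalize() is never called
    progress = (lr - emb_decay_start) / (learning_rate - emb_decay_start)
    return min(1.0, max(0.0, emb_decay_factor * progress))


# Example: base lr 1e-4, decay starting at lr 1e-5, current lr 5.5e-5, factor 1.0:
# progress = (5.5e-5 - 1e-5) / (1e-4 - 1e-5) = 0.5, so the weight is 0.5.
print(emb_decay_weight(5.5e-5, 1e-4, 1e-5, 1.0))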