diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 27 |
1 file changed, 15 insertions, 12 deletions
diff --git a/train_ti.py b/train_ti.py index 928b721..8631892 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -397,7 +397,7 @@ def parse_args(): | |||
397 | ) | 397 | ) |
398 | parser.add_argument( | 398 | parser.add_argument( |
399 | "--emb_decay_factor", | 399 | "--emb_decay_factor", |
400 | default=1, | 400 | default=0, |
401 | type=float, | 401 | type=float, |
402 | help="Embedding decay factor." | 402 | help="Embedding decay factor." |
403 | ) | 403 | ) |
@@ -532,13 +532,17 @@ class Checkpointer(CheckpointerBase): | |||
532 | 532 | ||
533 | @torch.no_grad() | 533 | @torch.no_grad() |
534 | def save_samples(self, step): | 534 | def save_samples(self, step): |
535 | unet = self.accelerator.unwrap_model(self.unet) | ||
535 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 536 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
536 | 537 | ||
537 | ema_context = self.ema_embeddings.apply_temporary( | 538 | ema_context = self.ema_embeddings.apply_temporary( |
538 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() | 539 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() |
539 | 540 | ||
540 | with ema_context: | 541 | with ema_context: |
541 | orig_dtype = text_encoder.dtype | 542 | orig_unet_dtype = unet.dtype |
543 | orig_text_encoder_dtype = text_encoder.dtype | ||
544 | |||
545 | unet.to(dtype=self.weight_dtype) | ||
542 | text_encoder.to(dtype=self.weight_dtype) | 546 | text_encoder.to(dtype=self.weight_dtype) |
543 | 547 | ||
544 | pipeline = VlpnStableDiffusion( | 548 | pipeline = VlpnStableDiffusion( |
@@ -552,7 +556,8 @@ class Checkpointer(CheckpointerBase): | |||
552 | 556 | ||
553 | super().save_samples(pipeline, step) | 557 | super().save_samples(pipeline, step) |
554 | 558 | ||
555 | text_encoder.to(dtype=orig_dtype) | 559 | unet.to(dtype=orig_unet_dtype) |
560 | text_encoder.to(dtype=orig_text_encoder_dtype) | ||
556 | 561 | ||
557 | del text_encoder | 562 | del text_encoder |
558 | del pipeline | 563 | del pipeline |
@@ -742,20 +747,17 @@ def main(): | |||
742 | warmup_epochs=args.lr_warmup_epochs, | 747 | warmup_epochs=args.lr_warmup_epochs, |
743 | ) | 748 | ) |
744 | 749 | ||
745 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 750 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
746 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 751 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler |
747 | ) | 752 | ) |
748 | 753 | ||
749 | vae.to(accelerator.device, dtype=weight_dtype) | 754 | vae.to(accelerator.device, dtype=weight_dtype) |
750 | unet.to(accelerator.device, dtype=weight_dtype) | ||
751 | 755 | ||
752 | if args.use_ema: | 756 | if args.use_ema: |
753 | ema_embeddings.to(accelerator.device) | 757 | ema_embeddings.to(accelerator.device) |
754 | 758 | ||
755 | if args.gradient_checkpointing: | 759 | if args.gradient_checkpointing: |
756 | unet.train() | 760 | unet.train() |
757 | else: | ||
758 | unet.eval() | ||
759 | 761 | ||
760 | @contextmanager | 762 | @contextmanager |
761 | def on_train(epoch: int): | 763 | def on_train(epoch: int): |
@@ -780,10 +782,11 @@ def main(): | |||
780 | 782 | ||
781 | @torch.no_grad() | 783 | @torch.no_grad() |
782 | def on_after_optimize(lr: float): | 784 | def on_after_optimize(lr: float): |
783 | text_encoder.text_model.embeddings.normalize( | 785 | if args.emb_decay_factor != 0: |
784 | args.emb_decay_target, | 786 | text_encoder.text_model.embeddings.normalize( |
785 | min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start)))) | 787 | args.emb_decay_target, |
786 | ) | 788 | min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start)))) |
789 | ) | ||
787 | 790 | ||
788 | if args.use_ema: | 791 | if args.use_ema: |
789 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 792 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |