author     Volpeon <git@volpeon.ink>   2023-01-14 10:02:30 +0100
committer  Volpeon <git@volpeon.ink>   2023-01-14 10:02:30 +0100
commit     a2c240c8c55dfe930657f66372975d6f26feb168 (patch)
tree       61c22b830098a6a28f885d9a0964b02a7f429e30
parent     Fix (diff)
download   textual-inversion-diff-a2c240c8c55dfe930657f66372975d6f26feb168.tar.gz
           textual-inversion-diff-a2c240c8c55dfe930657f66372975d6f26feb168.tar.bz2
           textual-inversion-diff-a2c240c8c55dfe930657f66372975d6f26feb168.zip
TI: Prepare UNet with Accelerate as well
-rw-r--r--  train_ti.py         | 27
-rw-r--r--  training/common.py  | 37
-rw-r--r--  training/util.py    | 16
3 files changed, 41 insertions(+), 39 deletions(-)
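For context: "preparing" a module with Hugging Face Accelerate means handing it to accelerator.prepare(), which takes care of device placement, the configured precision and any distributed wrapping, instead of moving the module by hand with .to(device, dtype). A minimal, self-contained sketch of that pattern (toy model for illustration, not this repository's code):

import torch
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(4, 4)  # stand-in for the UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# prepare() moves the model to the active device, applies the configured
# mixed precision and wraps it for distributed training when needed.
model, optimizer = accelerator.prepare(model, optimizer)

# Manual placement, which this commit drops for the UNet:
# model.to(accelerator.device)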
diff --git a/train_ti.py b/train_ti.py
index 928b721..8631892 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -397,7 +397,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_factor",
-        default=1,
+        default=0,
         type=float,
         help="Embedding decay factor."
     )
@@ -532,13 +532,17 @@ class Checkpointer(CheckpointerBase):
 
     @torch.no_grad()
     def save_samples(self, step):
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         ema_context = self.ema_embeddings.apply_temporary(
             text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
-            orig_dtype = text_encoder.dtype
+            orig_unet_dtype = unet.dtype
+            orig_text_encoder_dtype = text_encoder.dtype
+
+            unet.to(dtype=self.weight_dtype)
             text_encoder.to(dtype=self.weight_dtype)
 
             pipeline = VlpnStableDiffusion(
@@ -552,7 +556,8 @@ class Checkpointer(CheckpointerBase):
 
             super().save_samples(pipeline, step)
 
-            text_encoder.to(dtype=orig_dtype)
+            unet.to(dtype=orig_unet_dtype)
+            text_encoder.to(dtype=orig_text_encoder_dtype)
 
         del text_encoder
         del pipeline
@@ -742,20 +747,17 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
     vae.to(accelerator.device, dtype=weight_dtype)
-    unet.to(accelerator.device, dtype=weight_dtype)
 
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
     if args.gradient_checkpointing:
         unet.train()
-    else:
-        unet.eval()
 
     @contextmanager
     def on_train(epoch: int):
@@ -780,10 +782,11 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(
-            args.emb_decay_target,
-            min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
-        )
+        if args.emb_decay_factor != 0:
+            text_encoder.text_model.embeddings.normalize(
+                args.emb_decay_target,
+                min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
+            )
 
     if args.use_ema:
         ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
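The on_after_optimize change above gates embedding normalization on the new default of --emb_decay_factor (0, per the first hunk) and keeps the existing decay-weight schedule. As a standalone sketch of that schedule (reimplemented here for illustration; the argument names follow the argparse flags, the helper itself is not part of the repository):

def emb_decay_weight(lr: float, learning_rate: float,
                     emb_decay_factor: float, emb_decay_start: float) -> float:
    # Scale the decay factor by how far the current LR sits between
    # emb_decay_start and the base learning rate, clamped to [0, 1].
    t = (lr - emb_decay_start) / (learning_rate - emb_decay_start)
    return min(1.0, max(0.0, emb_decay_factor * t))

# Example: halfway through the LR range the weight is half the factor.
print(emb_decay_weight(lr=5e-5, learning_rate=1e-4,
                       emb_decay_factor=1.0, emb_decay_start=0.0))  # 0.5

With emb_decay_factor left at its new default of 0, the weight is always 0 and the added if-guard skips the normalize() call entirely.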
diff --git a/training/common.py b/training/common.py
index 8083137..5d1e3f9 100644
--- a/training/common.py
+++ b/training/common.py
@@ -316,30 +316,29 @@ def train_loop(
             cur_loss_val = AverageMeter()
             cur_acc_val = AverageMeter()
 
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(val_dataloader):
-                        loss, acc, bsz = loss_step(step, batch, True)
-
-                        loss = loss.detach_()
-                        acc = acc.detach_()
-
-                        cur_loss_val.update(loss, bsz)
-                        cur_acc_val.update(acc, bsz)
-
-                        avg_loss_val.update(loss, bsz)
-                        avg_acc_val.update(acc, bsz)
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        logs = {
-                            "val/loss": avg_loss_val.avg.item(),
-                            "val/acc": avg_acc_val.avg.item(),
-                            "val/cur_loss": loss.item(),
-                            "val/cur_acc": acc.item(),
-                        }
-                        local_progress_bar.set_postfix(**logs)
+            with torch.inference_mode(), on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loss_step(step, batch, True)
+
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
+
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    logs = {
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
+                    local_progress_bar.set_postfix(**logs)
 
             logs["val/cur_loss"] = cur_loss_val.avg.item()
             logs["val/cur_acc"] = cur_acc_val.avg.item()
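The only substantive change in this hunk is structural: the two nested context managers are collapsed into a single with statement, which Python treats as equivalent, and the body loses one indentation level. A tiny sketch of the equivalence (on_eval here is a stand-in contextmanager, not the training callback itself):

from contextlib import contextmanager
import torch

@contextmanager
def on_eval():
    yield

# Before: nested blocks, one extra indentation level.
with torch.inference_mode():
    with on_eval():
        pass

# After: comma-separated context managers in one statement.
with torch.inference_mode(), on_eval():
    pass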
diff --git a/training/util.py b/training/util.py
index 1008021..781cf04 100644
--- a/training/util.py
+++ b/training/util.py
@@ -134,11 +134,11 @@ class EMAModel:
     def __init__(
         self,
         parameters: Iterable[torch.nn.Parameter],
-        update_after_step=0,
-        inv_gamma=1.0,
-        power=2 / 3,
-        min_value=0.0,
-        max_value=0.9999,
+        update_after_step: int = 0,
+        inv_gamma: float = 1.0,
+        power: float = 2 / 3,
+        min_value: float = 0.0,
+        max_value: float = 0.9999,
     ):
         """
         @crowsonkb's notes on EMA Warmup:
@@ -165,7 +165,7 @@ class EMAModel:
         self.decay = 0.0
         self.optimization_step = 0
 
-    def get_decay(self, optimization_step):
+    def get_decay(self, optimization_step: int):
         """
         Compute the decay factor for the exponential moving average.
         """
@@ -276,5 +276,5 @@ class EMAModel:
             self.copy_to(parameters)
             yield
         finally:
-            for s_param, param in zip(original_params, parameters):
-                param.data.copy_(s_param.data)
+            for o_param, param in zip(original_params, parameters):
+                param.data.copy_(o_param.data)
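For reference, the renamed loop above sits in EMAModel.apply_temporary(), the context manager the Checkpointer uses as ema_context in save_samples(). A simplified, standalone sketch of its contract (this mirrors, but is not, the repository's exact implementation):

from contextlib import contextmanager
from typing import Iterable, List
import torch

@contextmanager
def apply_temporary(parameters: Iterable[torch.nn.Parameter],
                    shadow_params: List[torch.Tensor]):
    # Swap the EMA (shadow) weights into the live parameters for the duration
    # of the block, then restore the original weights on exit.
    parameters = list(parameters)
    original_params = [p.detach().clone() for p in parameters]
    try:
        for s_param, param in zip(shadow_params, parameters):
            param.data.copy_(s_param.data)
        yield
    finally:
        for o_param, param in zip(original_params, parameters):
            param.data.copy_(o_param.data)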