path: root/train_ti.py
author    Volpeon <git@volpeon.ink>  2023-01-06 09:07:18 +0100
committer Volpeon <git@volpeon.ink>  2023-01-06 09:07:18 +0100
commit    f4f90c487cbc247952689e906519d8e2eb21da99 (patch)
tree      fc308cdcf02c36437e8017fab5961294f86930fe /train_ti.py
parent    Log EMA decay (diff)
download  textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.gz
          textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.bz2
          textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.zip
Add contextmanager to EMAModel to apply weights temporarily
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 57
1 file changed, 24 insertions(+), 33 deletions(-)
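
The new code relies on EMAModel gaining an apply_temporary context manager that does what the removed block did by hand: snapshot the live weights, copy the EMA averages in, run the body, and restore the originals. The EMAModel change itself is not part of this file, so the following is only a minimal sketch of how such a method could look, inferred from the copy_to call and the swap/restore pattern in the removed lines:

    from contextlib import contextmanager

    import torch


    class EMAModel:
        def __init__(self, parameters, decay=0.9999):
            self.decay = decay
            # Shadow copies hold the exponential moving averages.
            self.shadow_params = [p.detach().clone() for p in parameters]

        def copy_to(self, parameters):
            # Overwrite the live parameters with the EMA averages.
            for shadow, param in zip(self.shadow_params, parameters):
                param.data.copy_(shadow.data)

        @contextmanager
        def apply_temporary(self, parameters):
            # Sketch only -- the real implementation lives outside train_ti.py.
            # Snapshot the live weights, swap in the EMA averages, and restore
            # the originals on exit, even if the body raises.
            parameters = list(parameters)
            original = [p.detach().clone() for p in parameters]
            try:
                self.copy_to(parameters)
                yield
            finally:
                for param, orig in zip(parameters, original):
                    param.data.copy_(orig)

Compared with the removed code, this also drops the need to deepcopy the whole embedding module and reassign it on the text encoder.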
diff --git a/train_ti.py b/train_ti.py
index 2f13128..aa2bf02 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -5,6 +5,7 @@ import logging
 import copy
 from pathlib import Path
 from functools import partial
+from contextlib import nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -509,20 +510,15 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        if self.ema_embeddings is not None:
-            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
-            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
-            self.ema_embeddings.copy_to(ema_weights.parameters())
-            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
-        for (token, ids) in zip(self.placeholder_token, self.new_ids):
-            text_encoder.text_model.embeddings.save_embed(
-                ids,
-                checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-            )
-
-        if self.ema_embeddings is not None:
-            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_token, self.new_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
 
         del text_encoder
 
@@ -530,30 +526,25 @@ class Checkpointer(CheckpointerBase):
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        if self.ema_embeddings is not None:
-            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
-            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
-            self.ema_embeddings.copy_to(ema_weights.parameters())
-            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
-
-        orig_dtype = text_encoder.dtype
-        text_encoder.to(dtype=self.weight_dtype)
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=self.unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        ).to(self.accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
+        with ema_context:
+            orig_dtype = text_encoder.dtype
+            text_encoder.to(dtype=self.weight_dtype)
 
-        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=self.unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        text_encoder.to(dtype=orig_dtype)
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
-        if self.ema_embeddings is not None:
-            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+            text_encoder.to(dtype=orig_dtype)
 
         del text_encoder
         del pipeline
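
Both hunks use the same caller-side idiom: build the context manager up front, substituting contextlib.nullcontext() when no EMA model is configured, so the with block itself needs no branching. A standalone illustration of the pattern (the names here are placeholders, not the script's API):

    from contextlib import nullcontext

    import torch

    embedding = torch.nn.Embedding(10, 4)
    ema_embeddings = None  # or an EMAModel tracking embedding.parameters()

    ema_context = (
        ema_embeddings.apply_temporary(embedding.parameters())
        if ema_embeddings is not None
        else nullcontext()
    )

    with ema_context:
        # With an EMA model, the embedding temporarily carries the averaged
        # weights here; with None, this is a no-op context.
        print(embedding.weight.sum())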