-rw-r--r--  train_ti.py       57
-rw-r--r--  training/util.py  12
2 files changed, 36 insertions, 33 deletions
diff --git a/train_ti.py b/train_ti.py
index 2f13128..aa2bf02 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -5,6 +5,7 @@ import logging
 import copy
 from pathlib import Path
 from functools import partial
+from contextlib import nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -509,20 +510,15 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        if self.ema_embeddings is not None:
-            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
-            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
-            self.ema_embeddings.copy_to(ema_weights.parameters())
-            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
-        for (token, ids) in zip(self.placeholder_token, self.new_ids):
-            text_encoder.text_model.embeddings.save_embed(
-                ids,
-                checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-            )
-
-        if self.ema_embeddings is not None:
-            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_token, self.new_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
 
         del text_encoder
 
@@ -530,30 +526,25 @@ class Checkpointer(CheckpointerBase):
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        if self.ema_embeddings is not None:
-            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
-            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
-            self.ema_embeddings.copy_to(ema_weights.parameters())
-            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
-
-        orig_dtype = text_encoder.dtype
-        text_encoder.to(dtype=self.weight_dtype)
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=self.unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        ).to(self.accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
+        with ema_context:
+            orig_dtype = text_encoder.dtype
+            text_encoder.to(dtype=self.weight_dtype)
 
-        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=self.unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        text_encoder.to(dtype=orig_dtype)
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
-        if self.ema_embeddings is not None:
-            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+            text_encoder.to(dtype=orig_dtype)
 
         del text_encoder
         del pipeline
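
Both modified Checkpointer methods now wrap their body in a single ema_context, so the EMA weights are swapped in only for the duration of the with-block and the live training weights return afterwards; when no EMA model is tracked, nullcontext() turns the block into a no-op. A minimal sketch of this conditional-context pattern (the embedding and ema names below are illustrative, not from the commit):

    from contextlib import nullcontext

    import torch

    embedding = torch.nn.Embedding(4, 8)   # stand-in for temp_token_embedding
    ema = None                              # or an EMAModel tracking embedding.parameters()

    ema_context = ema.apply_temporary(embedding.parameters()) if ema is not None else nullcontext()
    with ema_context:
        # EMA weights (if any) are active here; the previous weights are restored on exit
        torch.save(embedding.state_dict(), "embedding_checkpoint.bin")
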
diff --git a/training/util.py b/training/util.py
index 93b6248..6f1e85a 100644
--- a/training/util.py
+++ b/training/util.py
@@ -2,6 +2,7 @@ from pathlib import Path
 import json
 import copy
 from typing import Iterable
+from contextlib import contextmanager
 
 import torch
 from PIL import Image
@@ -259,3 +260,14 @@ class EMAModel:
             raise ValueError("collected_params must all be Tensors")
         if len(self.collected_params) != len(self.shadow_params):
             raise ValueError("collected_params and shadow_params must have the same length")
+
+    @contextmanager
+    def apply_temporary(self, parameters):
+        try:
+            parameters = list(parameters)
+            original_params = [p.clone() for p in parameters]
+            self.copy_to(parameters)
+            yield
+        finally:
+            for s_param, param in zip(original_params, parameters):
+                param.data.copy_(s_param.data)
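
The new EMAModel.apply_temporary() snapshots the live parameters with p.clone(), copies the EMA shadow weights onto them via copy_to(), and restores the snapshots in the finally block, so the original weights come back even if the body raises. A usage sketch, assuming the EMAModel constructor accepts the tracked parameter iterable and that its remaining arguments (not shown in this diff) can be left at their defaults:

    import torch

    from training.util import EMAModel

    model = torch.nn.Linear(2, 2)
    ema = EMAModel(model.parameters())  # assumption: other constructor arguments left at defaults

    before = [p.detach().clone() for p in model.parameters()]
    with ema.apply_temporary(model.parameters()):
        # inside the block the parameters hold the EMA shadow weights
        ema_weight = model.weight.detach().clone()
    # outside, the pre-existing weights are back, even if the block had raised
    assert all(torch.equal(b, p) for b, p in zip(before, model.parameters()))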