From f4f90c487cbc247952689e906519d8e2eb21da99 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 6 Jan 2023 09:07:18 +0100
Subject: Add contextmanager to EMAModel to apply weights temporarily

---
 train_ti.py | 57 ++++++++++++++++++++++++---------------------------------
 1 file changed, 24 insertions(+), 33 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 2f13128..aa2bf02 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -5,6 +5,7 @@ import logging
 import copy
 from pathlib import Path
 from functools import partial
+from contextlib import nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -509,20 +510,15 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        if self.ema_embeddings is not None:
-            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
-            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
-            self.ema_embeddings.copy_to(ema_weights.parameters())
-            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
-        for (token, ids) in zip(self.placeholder_token, self.new_ids):
-            text_encoder.text_model.embeddings.save_embed(
-                ids,
-                checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-            )
-
-        if self.ema_embeddings is not None:
-            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_token, self.new_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
 
         del text_encoder
 
@@ -530,30 +526,25 @@ class Checkpointer(CheckpointerBase):
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        if self.ema_embeddings is not None:
-            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
-            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
-            self.ema_embeddings.copy_to(ema_weights.parameters())
-            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
-
-        orig_dtype = text_encoder.dtype
-        text_encoder.to(dtype=self.weight_dtype)
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=self.unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        ).to(self.accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
+        with ema_context:
+            orig_dtype = text_encoder.dtype
+            text_encoder.to(dtype=self.weight_dtype)
 
-        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=self.unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        text_encoder.to(dtype=orig_dtype)
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
-        if self.ema_embeddings is not None:
-            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+            text_encoder.to(dtype=orig_dtype)
 
         del text_encoder
         del pipeline
-- 
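Note: EMAModel.apply_temporary itself is not part of this diff, since the view
above is limited to train_ti.py. As a rough sketch of the pattern the new call
sites rely on, assuming an EMAModel that keeps shadow copies of the weights and
exposes the copy_to(parameters) method seen in the removed code (the
constructor and decay update below are illustrative, not the repository's
actual code):

    from contextlib import contextmanager

    import torch


    # Illustrative sketch; the real EMAModel in this repo may differ.
    class EMAModel:
        def __init__(self, parameters, decay=0.9999):
            # Shadow copies tracking an exponential moving average of the weights.
            self.decay = decay
            self.shadow_params = [p.detach().clone() for p in parameters]

        @torch.no_grad()
        def step(self, parameters):
            # EMA update: shadow <- decay * shadow + (1 - decay) * param.
            for s_param, param in zip(self.shadow_params, parameters):
                s_param.mul_(self.decay).add_(param, alpha=1 - self.decay)

        def copy_to(self, parameters):
            # Overwrite the given parameters with the EMA shadow values.
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.data)

        @contextmanager
        def apply_temporary(self, parameters):
            parameters = list(parameters)
            # Back up the live training weights, then swap in the EMA weights.
            original = [param.detach().clone() for param in parameters]
            self.copy_to(parameters)
            try:
                yield
            finally:
                # Put the training weights back even if the body raises.
                for param, orig in zip(parameters, original):
                    param.data.copy_(orig)

This is what lets the call sites collapse the save/restore bookkeeping into a
single "with ema_context:" block, with nullcontext() keeping the code path
uniform when no EMA model is configured.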