From f4f90c487cbc247952689e906519d8e2eb21da99 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 6 Jan 2023 09:07:18 +0100 Subject: Add contextmanager to EMAModel to apply weights temporarily --- train_ti.py | 57 ++++++++++++++++++++++++-------------------------------- training/util.py | 12 ++++++++++++ 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/train_ti.py b/train_ti.py index 2f13128..aa2bf02 100644 --- a/train_ti.py +++ b/train_ti.py @@ -5,6 +5,7 @@ import logging import copy from pathlib import Path from functools import partial +from contextlib import nullcontext import torch import torch.utils.checkpoint @@ -509,20 +510,15 @@ class Checkpointer(CheckpointerBase): text_encoder = self.accelerator.unwrap_model(self.text_encoder) - if self.ema_embeddings is not None: - orig_weights = text_encoder.text_model.embeddings.temp_token_embedding - ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding) - self.ema_embeddings.copy_to(ema_weights.parameters()) - text_encoder.text_model.embeddings.temp_token_embedding = ema_weights + ema_context = self.ema_embeddings.apply_temporary( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() - for (token, ids) in zip(self.placeholder_token, self.new_ids): - text_encoder.text_model.embeddings.save_embed( - ids, - checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") - ) - - if self.ema_embeddings is not None: - text_encoder.text_model.embeddings.temp_token_embedding = orig_weights + with ema_context: + for (token, ids) in zip(self.placeholder_token, self.new_ids): + text_encoder.text_model.embeddings.save_embed( + ids, + checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") + ) del text_encoder @@ -530,30 +526,25 @@ class Checkpointer(CheckpointerBase): def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): text_encoder = self.accelerator.unwrap_model(self.text_encoder) - if self.ema_embeddings is not None: - orig_weights = text_encoder.text_model.embeddings.temp_token_embedding - ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding) - self.ema_embeddings.copy_to(ema_weights.parameters()) - text_encoder.text_model.embeddings.temp_token_embedding = ema_weights - - orig_dtype = text_encoder.dtype - text_encoder.to(dtype=self.weight_dtype) + ema_context = self.ema_embeddings.apply_temporary( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=self.vae, - unet=self.unet, - tokenizer=self.tokenizer, - scheduler=self.scheduler, - ).to(self.accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) + with ema_context: + orig_dtype = text_encoder.dtype + text_encoder.to(dtype=self.weight_dtype) - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=self.vae, + unet=self.unet, + tokenizer=self.tokenizer, + scheduler=self.scheduler, + ).to(self.accelerator.device) + pipeline.set_progress_bar_config(dynamic_ncols=True) - text_encoder.to(dtype=orig_dtype) + super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) - if self.ema_embeddings is not None: - text_encoder.text_model.embeddings.temp_token_embedding = orig_weights + text_encoder.to(dtype=orig_dtype) del text_encoder del pipeline diff --git a/training/util.py b/training/util.py index 93b6248..6f1e85a 100644 --- a/training/util.py +++ b/training/util.py @@ -2,6 +2,7 @@ from pathlib import Path import json import copy from typing import Iterable +from contextlib import contextmanager import torch from PIL import Image @@ -259,3 +260,14 @@ class EMAModel: raise ValueError("collected_params must all be Tensors") if len(self.collected_params) != len(self.shadow_params): raise ValueError("collected_params and shadow_params must have the same length") + + @contextmanager + def apply_temporary(self, parameters): + try: + parameters = list(parameters) + original_params = [p.clone() for p in parameters] + self.copy_to(parameters) + yield + finally: + for s_param, param in zip(original_params, parameters): + param.data.copy_(s_param.data) -- cgit v1.2.3-54-g00ecf