From 5c115a212e40ff177c734351601f9babe29419ce Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 5 Jan 2023 22:05:25 +0100
Subject: Added EMA to TI

---
 train_ti.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 58 insertions(+), 1 deletion(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 98385dd..dc36e42 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -2,6 +2,7 @@ import argparse
 import math
 import datetime
 import logging
+import copy
 from pathlib import Path
 from functools import partial

@@ -24,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.util import AverageMeter, CheckpointerBase, save_args
+from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
 from models.clip.tokenizer import MultiCLIPTokenizer
@@ -266,6 +267,27 @@ def parse_args():
         default=None,
         help="Minimum learning rate in the lr scheduler."
     )
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=6/7
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
@@ -449,6 +471,7 @@ class Checkpointer(CheckpointerBase):
         unet,
         tokenizer,
         text_encoder,
+        ema_embeddings,
         scheduler,
         placeholder_token,
         new_ids,
@@ -473,6 +496,7 @@ class Checkpointer(CheckpointerBase):
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
         self.scheduler = scheduler
         self.placeholder_token = placeholder_token
         self.new_ids = new_ids
@@ -486,17 +510,33 @@ class Checkpointer(CheckpointerBase):

         text_encoder = self.accelerator.unwrap_model(self.text_encoder)

+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         for (token, ids) in zip(self.placeholder_token, self.new_ids):
             text_encoder.text_model.embeddings.save_embed(
                 ids,
                 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
             )

+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder

     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         orig_dtype = text_encoder.dtype
         text_encoder.to(dtype=self.weight_dtype)

@@ -513,6 +553,9 @@ class Checkpointer(CheckpointerBase):

         text_encoder.to(dtype=orig_dtype)

+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder
         del pipeline

@@ -567,6 +610,7 @@ def main():
         text_encoder.gradient_checkpointing_enable()

     embeddings = patch_managed_embeddings(text_encoder)
+    ema_embeddings = None

     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -592,6 +636,14 @@ def main():

         print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")

+    if args.use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+        )
+
     vae.requires_grad_(False)
     unet.requires_grad_(False)

@@ -788,6 +840,7 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
+    ema_embeddings.to(accelerator.device)

     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -883,6 +936,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
         scheduler=checkpoint_scheduler,
         placeholder_token=args.placeholder_token,
         new_ids=new_ids,
@@ -935,6 +989,9 @@ def main():

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)

-- 
cgit v1.2.3-54-g00ecf
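
For review context, the wiring this patch adds reduces to a small pattern: build an EMA copy of the temp token embedding's parameters, call step() after every synced optimizer step, and temporarily swap the EMA weights in around checkpointing/sampling. Below is a minimal, self-contained sketch of that pattern. TinyEMA and the stand-alone nn.Embedding are hypothetical stand-ins (not code from this repository), and the warmup decay rule shown is only an assumption about what training.util.EMAModel's inv_gamma/power/max_value hyperparameters parameterize.

# Sketch of the EMA usage pattern introduced by this patch (stand-ins, not repo code).
import copy

import torch
import torch.nn as nn


class TinyEMA:
    """Hypothetical stand-in for training.util.EMAModel."""

    def __init__(self, parameters, inv_gamma=1.0, power=6 / 7, max_value=0.9999):
        self.shadow = [p.detach().clone() for p in parameters]
        self.inv_gamma = inv_gamma
        self.power = power
        self.max_value = max_value
        self.optimization_step = 0

    def _decay(self):
        # Assumed warmup schedule: decay starts near 0 and approaches max_value.
        value = 1 - (1 + self.optimization_step / self.inv_gamma) ** -self.power
        return max(0.0, min(value, self.max_value))

    @torch.no_grad()
    def step(self, parameters):
        self.optimization_step += 1
        decay = self._decay()
        for shadow, param in zip(self.shadow, parameters):
            shadow.mul_(decay).add_(param, alpha=1 - decay)

    @torch.no_grad()
    def copy_to(self, parameters):
        for shadow, param in zip(self.shadow, parameters):
            param.copy_(shadow)


# Stand-in for text_encoder.text_model.embeddings.temp_token_embedding.
temp_token_embedding = nn.Embedding(4, 8)
ema = TinyEMA(temp_token_embedding.parameters(),
              inv_gamma=1.0, power=6 / 7, max_value=0.9999)
optimizer = torch.optim.AdamW(temp_token_embedding.parameters(), lr=1e-3)

for _ in range(10):  # toy training loop
    loss = temp_token_embedding(torch.tensor([0, 1])).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # As in the patch: update the EMA copy once per synced optimizer step.
    ema.step(temp_token_embedding.parameters())

# As in Checkpointer.checkpoint()/save_samples(): swap in the EMA weights
# for saving or sampling, then restore the live training weights afterwards.
orig_weights = temp_token_embedding
ema_weights = copy.deepcopy(temp_token_embedding)
ema.copy_to(ema_weights.parameters())
temp_token_embedding = ema_weights   # save / sample with EMA weights here
temp_token_embedding = orig_weights  # restore for continued training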