From 6c64f769043c8212b1a5778e857af691a828798d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 5 Jan 2023 10:19:38 +0100
Subject: Various cleanups

---
 train_ti.py | 86 +++++++++++++++----------------------------------------------
 1 file changed, 21 insertions(+), 65 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 8ada98c..5df6850 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,9 +3,9 @@ import math
 import datetime
 import logging
 from pathlib import Path
+from functools import partial
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from accelerate import Accelerator
@@ -18,9 +18,10 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel
 from slugify import slugify
 
-from common import load_config, load_embeddings_from_dir
+from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
+from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -570,8 +571,8 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
-        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
 
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = [
@@ -579,13 +580,15 @@ def main():
         for token in args.initializer_token
     ]
 
-    new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+    new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
     embeddings.resize(len(tokenizer))
 
-    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-        embeddings.add_embed(new_token.ids, init_ids)
+    init_ratios = [
+        embeddings.add_embed(new_id, init_ids)
+        for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
+    ]
 
-    print(f"Added {len(new_tokens)} new tokens.")
+    print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
@@ -807,63 +810,16 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(step: int, batch, eval: bool = False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
-        timesteps = torch.randint(
-            0,
-            noise_scheduler.config.num_train_timesteps,
-            (bsz,),
-            generator=timesteps_gen,
-            device=latents.device,
-        )
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-        encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
+    loop = partial(
+        run_model,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        unet=unet,
+        prompt_processor=prompt_processor,
+        num_class_images=args.num_class_images,
+        prior_loss_weight=args.prior_loss_weight,
+        seed=args.seed,
+    )
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
--
cgit v1.2.3-54-g00ecf