From 68540b27849564994d921968a36faa9b997e626d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Dec 2022 09:17:25 +0100
Subject: Moved common training code into separate module

---
 train_dreambooth.py      | 126 +++++----------------------------
 train_ti.py              | 175 +++++++++++++++---------------------------
 training/optimization.py |   2 +-
 training/util.py         | 131 +++++++++++++++++++++++++++++++++++
 4 files changed, 203 insertions(+), 231 deletions(-)
 create mode 100644 training/util.py

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0f8fece..9749c62 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -25,6 +24,7 @@ from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -385,41 +385,7 @@ def parse_args():
     return args
 
 
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.sum = self.count = self.avg = 0
-
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
@@ -437,9 +403,20 @@ class Checkpointer:
         sample_image_size,
         sample_batches,
         sample_batch_size,
-        seed
+        seed,
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -447,14 +424,6 @@ class Checkpointer:
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
-        self.output_dir = output_dir
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
     def save_model(self):
@@ -481,8 +450,6 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
         unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
@@ -495,72 +462,11 @@ class Checkpointer:
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
         del unet
         del text_encoder
         del pipeline
-        del generator
-        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
diff --git a/train_ti.py b/train_ti.py
index 9616db6..198cf37 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -7,7 +7,6 @@ import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -17,7 +16,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -26,6 +24,7 @@ from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -138,7 +137,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0,
+        default=0.1,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -355,27 +354,7 @@ def parse_args():
     return args
 
 
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
@@ -394,21 +373,24 @@ class Checkpointer:
         sample_batch_size,
         seed
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.output_dir = output_dir
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -431,9 +413,7 @@ class Checkpointer:
         del learned_embeds
 
     @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a sample image
@@ -446,71 +426,10 @@ class Checkpointer:
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
         del text_encoder
         del pipeline
-        del generator
-        del stable_latents
 
         if torch.cuda.is_available():
            torch.cuda.empty_cache()
@@ -814,7 +733,14 @@ def main():
 
     # Only show the progress bar once on each machine.
     global_step = 0
-    min_val_loss = np.inf
+
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
@@ -835,9 +761,7 @@ def main():
     )
 
     if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+        checkpointer.save_samples(global_step_offset, args.sample_steps)
 
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -910,6 +834,8 @@ def main():
             else:
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
+            acc = (model_pred == latents).float().mean()
+
             accelerator.backward(loss)
 
             optimizer.step()
@@ -922,8 +848,8 @@ def main():
                 text_encoder.get_input_embeddings(
                 ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
 
-            loss = loss.detach().item()
-            train_loss += loss
+            avg_loss.update(loss.detach_(), bsz)
+            avg_acc.update(acc.detach_(), bsz)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -932,7 +858,13 @@ def main():
 
                 global_step += 1
 
-                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
 
                 accelerator.log(logs, step=global_step)
 
@@ -941,12 +873,9 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        train_loss /= len(train_dataloader)
-
         accelerator.wait_for_everyone()
 
         text_encoder.eval()
-        val_loss = 0.0
 
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
@@ -976,29 +905,37 @@ def main():
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-                loss = loss.detach().item()
-                val_loss += loss
+                acc = (model_pred == latents).float().mean()
+
+                avg_loss_val.update(loss.detach_(), bsz)
+                avg_acc_val.update(acc.detach_(), bsz)
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
+                logs = {
+                    "val/loss": avg_loss_val.avg.item(),
+                    "val/acc": avg_acc_val.avg.item(),
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
                 local_progress_bar.set_postfix(**logs)
 
-            val_loss /= len(val_dataloader)
-
-            accelerator.log({"val/loss": val_loss}, step=global_step)
+            accelerator.log({
+                "val/loss": avg_loss_val.avg.item(),
+                "val/acc": avg_acc_val.avg.item(),
+            }, step=global_step)
 
         local_progress_bar.clear()
        global_progress_bar.clear()
 
         if accelerator.is_main_process:
-            if min_val_loss > val_loss:
+            if avg_acc_val.avg.item() > max_acc_val:
                 accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
-                min_val_loss = val_loss
+                max_acc_val = avg_acc_val.avg.item()
 
             if (epoch + 1) % args.checkpoint_frequency == 0:
                 checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -1007,9 +944,7 @@ def main():
             })
 
         if (epoch + 1) % args.sample_frequency == 0:
-            checkpointer.save_samples(
-                global_step + global_step_offset,
-                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+            checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
diff --git a/training/optimization.py b/training/optimization.py
index 0e603fa..c501ed9 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ from diffusers.utils import logging
 
 logger = logging.get_logger(__name__)
 
-def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.4, last_epoch=-1):
+def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.001, mid_point=0.4, last_epoch=-1):
     """
     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
     a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
diff --git a/training/util.py b/training/util.py
new file mode 100644
index 0000000..e8d22ae
--- /dev/null
+++ b/training/util.py
@@ -0,0 +1,131 @@
+from pathlib import Path
+import json
+
+import torch
+from PIL import Image
+
+
+def freeze_params(params):
+    for param in params:
+        param.requires_grad = False
+
+
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=(i % cols*w, i//cols*h))
+    return grid
+
+
+class AverageMeter:
+    def __init__(self, name=None):
+        self.name = name
+        self.reset()
+
+    def reset(self):
+        self.sum = self.count = self.avg = 0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class CheckpointerBase:
+    def __init__(
+        self,
+        datamodule,
+        output_dir: Path,
+        instance_identifier,
+        placeholder_token,
+        placeholder_token_id,
+        sample_image_size,
+        sample_batches,
+        sample_batch_size,
+        seed
+    ):
+        self.datamodule = datamodule
+        self.output_dir = output_dir
+        self.instance_identifier = instance_identifier
+        self.placeholder_token = placeholder_token
+        self.placeholder_token_id = placeholder_token_id
+        self.sample_image_size = sample_image_size
+        self.seed = seed or torch.random.seed()
+        self.sample_batches = sample_batches
+        self.sample_batch_size = sample_batch_size
+
+    @torch.no_grad()
+    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+        samples_path = Path(self.output_dir).joinpath("samples")
+
+        train_data = self.datamodule.train_dataloader()
+        val_data = self.datamodule.val_dataloader()
+
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
+
+        with torch.autocast("cuda"), torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+                file_path.parent.mkdir(parents=True, exist_ok=True)
+
+                data_enum = enumerate(data)
+
+                batches = [
+                    batch
+                    for j, batch in data_enum
+                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
+                ]
+                prompts = [
+                    prompt.format(identifier=self.instance_identifier)
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ]
+                nprompts = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ]
+
+                for i in range(self.sample_batches):
+                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+
+                    samples = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        height=self.sample_image_size,
+                        width=self.sample_image_size,
+                        image=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
+                        guidance_scale=guidance_scale,
+                        eta=eta,
+                        num_inference_steps=num_inference_steps,
+                        output_type='pil'
+                    ).images
+
+                    all_samples += samples
+
+                    del samples
+
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path, quality=85)
+
+                del all_samples
+                del image_grid
+
+        del generator
+        del stable_latents
--
cgit v1.2.3-70-g09d2
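
Editor's note (not part of the commit): a minimal usage sketch of the AverageMeter helper now exported from training/util.py, which replaces the old min_val_loss/train_loss bookkeeping in train_ti.py. The loop, batch size, and loss tensor below are hypothetical stand-ins for the real dataloader loop; running this assumes the repository root is on PYTHONPATH so that training.util imports.

    import torch

    from training.util import AverageMeter

    avg_loss = AverageMeter()

    for step in range(4):          # stand-in for enumerate(train_dataloader)
        bsz = 8                    # stand-in for the batch size of this step
        loss = torch.rand(())      # stand-in for the detached MSE loss tensor
        # update() weights each value by its batch size: sum += val * n
        avg_loss.update(loss.detach_(), bsz)

    # .avg is the running sample-weighted mean, logged as "train/loss" above
    print(f"train/loss: {avg_loss.avg.item():.4f}")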