From 9b808b6ca102cfec0c273626a0bcadf897b7c942 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 19 Dec 2022 21:10:58 +0100 Subject: Improved dataset prompt handling, fixed --- train_ti.py | 1032 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1032 insertions(+) create mode 100644 train_ti.py (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py new file mode 100644 index 0000000..dbfe58c --- /dev/null +++ b/train_ti.py @@ -0,0 +1,1032 @@ +import argparse +import itertools +import math +import os +import datetime +import logging +import json +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import LoggerType, set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel +from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from slugify import slugify + +from common import load_text_embeddings, load_text_embedding +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from pipelines.util import set_use_memory_efficient_attention_xformers +from data.csv import CSVDataModule, CSVDataItem +from training.optimization import get_one_cycle_schedule +from models.clip.prompt import PromptProcessor + +logger = get_logger(__name__) + + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.benchmark = True + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Simple example of a training script." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--train_data_file", + type=str, + default=None, + help="A CSV file containing the training data." + ) + parser.add_argument( + "--train_data_template", + type=str, + default="template", + ) + parser.add_argument( + "--instance_identifier", + type=str, + default=None, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--class_identifier", + type=str, + default=None, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--placeholder_token", + type=str, + nargs='*', + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--initializer_token", + type=str, + nargs='*', + help="A token to use as initializer word." + ) + parser.add_argument( + "--num_class_images", + type=int, + default=400, + help="How many class images to generate." + ) + parser.add_argument( + "--repeats", + type=int, + default=1, + help="How many times to repeat the training data." + ) + parser.add_argument( + "--output_dir", + type=str, + default="output/text-inversion", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--embeddings_dir", + type=str, + default=None, + help="The embeddings directory where Textual Inversion embeddings are stored.", + ) + parser.add_argument( + "--mode", + type=str, + default=None, + help="A mode to filter the dataset.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=768, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument( + "--tag_dropout", + type=float, + default=0, + help="Tag dropout probability.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" + " process." + ), + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=100 + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=True, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="one_cycle", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup", "one_cycle"]' + ), + ) + parser.add_argument( + "--lr_warmup_epochs", + type=int, + default=10, + help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_cycles", + type=int, + default=None, + help="Number of restart cycles in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer." + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer." + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=1e-2, + help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer" + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument( + "--checkpoint_frequency", + type=int, + default=5, + help="How often to save a checkpoint and sample image (in epochs)", + ) + parser.add_argument( + "--sample_frequency", + type=int, + default=1, + help="How often to save a checkpoint and sample image (in epochs)", + ) + parser.add_argument( + "--sample_image_size", + type=int, + default=768, + help="Size of sample images", + ) + parser.add_argument( + "--sample_batches", + type=int, + default=1, + help="Number of sample batches to generate per checkpoint", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=1, + help="Number of samples to generate per batch", + ) + parser.add_argument( + "--valid_set_size", + type=int, + default=None, + help="Number of images in the validation dataset." + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_steps", + type=int, + default=15, + help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss." + ) + parser.add_argument( + "--noise_timesteps", + type=int, + default=1000, + ) + parser.add_argument( + "--resume_from", + type=str, + default=None, + help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" + ) + parser.add_argument( + "--global_step", + type=int, + default=0, + ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to a JSON configuration file containing arguments for invoking this script." + ) + + args = parser.parse_args() + if args.config is not None: + with open(args.config, 'rt') as f: + args = parser.parse_args( + namespace=argparse.Namespace(**json.load(f)["args"])) + + if args.train_data_file is None: + raise ValueError("You must specify --train_data_file") + + if args.pretrained_model_name_or_path is None: + raise ValueError("You must specify --pretrained_model_name_or_path") + + if isinstance(args.initializer_token, str): + args.initializer_token = [args.initializer_token] + + if len(args.initializer_token) == 0: + raise ValueError("You must specify --initializer_token") + + if isinstance(args.placeholder_token, str): + args.placeholder_token = [args.placeholder_token] + + if len(args.placeholder_token) == 0: + args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] + + if len(args.placeholder_token) != len(args.initializer_token): + raise ValueError("You must specify --placeholder_token") + + if args.output_dir is None: + raise ValueError("You must specify --output_dir") + + return args + + +def freeze_params(params): + for param in params: + param.requires_grad = False + + +def save_args(basepath: Path, args, extra={}): + info = {"args": vars(args)} + info["args"].update(extra) + with open(basepath.joinpath("args.json"), "w") as f: + json.dump(info, f, indent=4) + + +def make_grid(images, rows, cols): + w, h = images[0].size + grid = Image.new('RGB', size=(cols*w, rows*h)) + for i, image in enumerate(images): + grid.paste(image, box=(i % cols*w, i//cols*h)) + return grid + + +class Checkpointer: + def __init__( + self, + datamodule, + accelerator, + vae, + unet, + tokenizer, + text_encoder, + scheduler, + instance_identifier, + placeholder_token, + placeholder_token_id, + output_dir: Path, + sample_image_size, + sample_batches, + sample_batch_size, + seed + ): + self.datamodule = datamodule + self.accelerator = accelerator + self.vae = vae + self.unet = unet + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.scheduler = scheduler + self.instance_identifier = instance_identifier + self.placeholder_token = placeholder_token + self.placeholder_token_id = placeholder_token_id + self.output_dir = output_dir + self.sample_image_size = sample_image_size + self.seed = seed or torch.random.seed() + self.sample_batches = sample_batches + self.sample_batch_size = sample_batch_size + + @torch.no_grad() + def checkpoint(self, step, postfix): + print("Saving checkpoint for step %d..." % step) + + checkpoints_path = self.output_dir.joinpath("checkpoints") + checkpoints_path.mkdir(parents=True, exist_ok=True) + + text_encoder = self.accelerator.unwrap_model(self.text_encoder) + + for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): + # Save a checkpoint + learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id] + learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} + + filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) + torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) + + del text_encoder + del learned_embeds + + @torch.no_grad() + def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): + samples_path = Path(self.output_dir).joinpath("samples") + + text_encoder = self.accelerator.unwrap_model(self.text_encoder) + + # Save a sample image + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=self.vae, + unet=self.unet, + tokenizer=self.tokenizer, + scheduler=self.scheduler, + ).to(self.accelerator.device) + pipeline.set_progress_bar_config(dynamic_ncols=True) + + train_data = self.datamodule.train_dataloader() + val_data = self.datamodule.val_dataloader() + + generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) + stable_latents = torch.randn( + (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), + device=pipeline.device, + generator=generator, + ) + + with torch.autocast("cuda"), torch.inference_mode(): + for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: + all_samples = [] + file_path = samples_path.joinpath(pool, f"step_{step}.jpg") + file_path.parent.mkdir(parents=True, exist_ok=True) + + data_enum = enumerate(data) + + batches = [ + batch + for j, batch in data_enum + if j * data.batch_size < self.sample_batch_size * self.sample_batches + ] + prompts = [ + prompt.format(identifier=self.instance_identifier) + for batch in batches + for prompt in batch["prompts"] + ] + nprompts = [ + prompt + for batch in batches + for prompt in batch["nprompts"] + ] + + for i in range(self.sample_batches): + prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] + nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] + + samples = pipeline( + prompt=prompt, + negative_prompt=nprompt, + height=self.sample_image_size, + width=self.sample_image_size, + image=latents[:len(prompt)] if latents is not None else None, + generator=generator if latents is not None else None, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + output_type='pil' + ).images + + all_samples += samples + + del samples + + image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) + image_grid.save(file_path, quality=85) + + del all_samples + del image_grid + + del text_encoder + del pipeline + del generator + del stable_latents + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def main(): + args = parse_args() + + instance_identifier = args.instance_identifier + + if len(args.placeholder_token) != 0: + instance_identifier = instance_identifier.format(args.placeholder_token[0]) + + global_step_offset = args.global_step + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) + basepath.mkdir(parents=True, exist_ok=True) + + accelerator = Accelerator( + log_with=LoggerType.TENSORBOARD, + logging_dir=f"{basepath}", + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision + ) + + logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) + + args.seed = args.seed or (torch.random.seed() >> 32) + set_seed(args.seed) + + # Load the tokenizer and add the placeholder token as a additional special token + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') + checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder='scheduler') + + vae.enable_slicing() + set_use_memory_efficient_attention_xformers(unet, True) + set_use_memory_efficient_attention_xformers(vae, True) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + if args.embeddings_dir is not None: + embeddings_dir = Path(args.embeddings_dir) + if not embeddings_dir.exists() or not embeddings_dir.is_dir(): + raise ValueError("--embeddings_dir must point to an existing directory") + added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) + print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}") + + # Convert the initializer_token, placeholder_token to ids + initializer_token_ids = torch.stack([ + torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) + for token in args.initializer_token + ]) + + num_added_tokens = tokenizer.add_tokens(args.placeholder_token) + print(f"Added {num_added_tokens} new tokens.") + + placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + + if args.resume_from is not None: + resumepath = Path(args.resume_from).joinpath("checkpoints") + + for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): + load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin")) + + original_token_embeds = token_embeds.clone().to(accelerator.device) + + initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) + for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): + token_embeds[token_id] = embeddings + + index_fixed_tokens = torch.arange(len(tokenizer)) + index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] + + # Freeze vae and unet + freeze_params(vae.parameters()) + freeze_params(unet.parameters()) + # Freeze all parameters except for the token embeddings in text encoder + freeze_params(itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + text_encoder.text_model.embeddings.position_embedding.parameters(), + )) + + prompt_processor = PromptProcessor(tokenizer, text_encoder) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * + args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Initialize the optimizer + optimizer = optimizer_class( + text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + def keyword_filter(item: CSVDataItem): + return any(keyword in item.prompt for keyword in args.placeholder_token) + + def collate_fn(examples): + prompts = [example["prompts"] for example in examples] + nprompts = [example["nprompts"] for example in examples] + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # concat class and instance examples for prior preservation + if args.num_class_images != 0 and "class_prompt_ids" in examples[0]: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) + + inputs = prompt_processor.unify_input_ids(input_ids) + + batch = { + "prompts": prompts, + "nprompts": nprompts, + "input_ids": inputs.input_ids, + "pixel_values": pixel_values, + "attention_mask": inputs.attention_mask, + } + return batch + + datamodule = CSVDataModule( + data_file=args.train_data_file, + batch_size=args.train_batch_size, + prompt_processor=prompt_processor, + instance_identifier=args.instance_identifier, + class_identifier=args.class_identifier, + class_subdir="cls", + num_class_images=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + mode=args.mode, + dropout=args.tag_dropout, + center_crop=args.center_crop, + template_key=args.train_data_template, + valid_set_size=args.valid_set_size, + num_workers=args.dataloader_num_workers, + filter=keyword_filter, + collate_fn=collate_fn + ) + + datamodule.prepare_data() + datamodule.setup() + + if args.num_class_images != 0: + missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] + + if len(missing_data) != 0: + batched_data = [ + missing_data[i:i+args.sample_batch_size] + for i in range(0, len(missing_data), args.sample_batch_size) + ] + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=checkpoint_scheduler, + ).to(accelerator.device) + pipeline.set_progress_bar_config(dynamic_ncols=True) + + with torch.autocast("cuda"), torch.inference_mode(): + for batch in batched_data: + image_name = [item.class_image_path for item in batch] + prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] + nprompt = [item.nprompt for item in batch] + + images = pipeline( + prompt=prompt, + negative_prompt=nprompt, + num_inference_steps=args.sample_steps + ).images + + for i, image in enumerate(images): + image.save(image_name[i]) + + del pipeline + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + train_dataloader = datamodule.train_dataloader() + val_dataloader = datamodule.val_dataloader() + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps + + if args.lr_scheduler == "one_cycle": + lr_scheduler = get_one_cycle_schedule( + optimizer=optimizer, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + elif args.lr_scheduler == "cosine_with_restarts": + lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_cycles or math.ceil(math.sqrt( + ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))), + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) + + # Move vae and unet to device + vae.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device, dtype=weight_dtype) + + # Keep vae and unet in eval mode as we don't train these + vae.eval() + unet.eval() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + num_val_steps_per_epoch = len(val_dataloader) + num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + val_steps = num_val_steps_per_epoch * num_epochs + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + config = vars(args).copy() + config["initializer_token"] = " ".join(config["initializer_token"]) + config["placeholder_token"] = " ".join(config["placeholder_token"]) + accelerator.init_trackers("textual_inversion", config=config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + + global_step = 0 + min_val_loss = np.inf + + checkpointer = Checkpointer( + datamodule=datamodule, + accelerator=accelerator, + vae=vae, + unet=unet, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=checkpoint_scheduler, + instance_identifier=args.instance_identifier, + placeholder_token=args.placeholder_token, + placeholder_token_id=placeholder_token_id, + output_dir=basepath, + sample_image_size=args.sample_image_size, + sample_batch_size=args.sample_batch_size, + sample_batches=args.sample_batches, + seed=args.seed + ) + + if accelerator.is_main_process: + checkpointer.save_samples( + 0, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) + + local_progress_bar = tqdm( + range(num_update_steps_per_epoch + num_val_steps_per_epoch), + disable=not accelerator.is_local_main_process, + dynamic_ncols=True + ) + local_progress_bar.set_description("Epoch X / Y") + + global_progress_bar = tqdm( + range(args.max_train_steps + val_steps), + disable=not accelerator.is_local_main_process, + dynamic_ncols=True + ) + global_progress_bar.set_description("Total progress") + + try: + for epoch in range(num_epochs): + local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") + local_progress_bar.reset() + + text_encoder.train() + train_loss = 0.0 + + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) + encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.num_class_images != 0: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + + optimizer.step() + if not accelerator.optimizer_step_was_skipped: + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + text_encoder.get_input_embeddings( + ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] + + loss = loss.detach().item() + train_loss += loss + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + local_progress_bar.update(1) + global_progress_bar.update(1) + + global_step += 1 + + logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} + + accelerator.log(logs, step=global_step) + + local_progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + train_loss /= len(train_dataloader) + + accelerator.wait_for_everyone() + + text_encoder.eval() + val_loss = 0.0 + + with torch.inference_mode(): + for step, batch in enumerate(val_dataloader): + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 + + noise = torch.randn_like(latents) + bsz = latents.shape[0] + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) + encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) + + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + loss = loss.detach().item() + val_loss += loss + + if accelerator.sync_gradients: + local_progress_bar.update(1) + global_progress_bar.update(1) + + logs = {"val/loss": loss} + local_progress_bar.set_postfix(**logs) + + val_loss /= len(val_dataloader) + + accelerator.log({"val/loss": val_loss}, step=global_step) + + local_progress_bar.clear() + global_progress_bar.clear() + + if accelerator.is_main_process: + if min_val_loss > val_loss: + accelerator.print( + f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") + checkpointer.checkpoint(global_step + global_step_offset, "milestone") + min_val_loss = val_loss + + if (epoch + 1) % args.checkpoint_frequency == 0: + checkpointer.checkpoint(global_step + global_step_offset, "training") + save_args(basepath, args, { + "global_step": global_step + global_step_offset + }) + + if (epoch + 1) % args.sample_frequency == 0: + checkpointer.save_samples( + global_step + global_step_offset, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) + + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + print("Finished! Saving final checkpoint and resume state.") + checkpointer.checkpoint(global_step + global_step_offset, "end") + save_args(basepath, args, { + "global_step": global_step + global_step_offset + }) + accelerator.end_training() + + except KeyboardInterrupt: + if accelerator.is_main_process: + print("Interrupted, saving checkpoint and resume state...") + checkpointer.checkpoint(global_step + global_step_offset, "end") + save_args(basepath, args, { + "global_step": global_step + global_step_offset + }) + accelerator.end_training() + quit() + + +if __name__ == "__main__": + main() -- cgit v1.2.3-54-g00ecf