From 300deaa789a0321f32d5e7f04d9860eaa258110e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Oct 2022 19:22:22 +0200 Subject: Add Textual Inversion with class dataset (a la Dreambooth) --- data/dreambooth/csv.py | 11 +- data/dreambooth/prompt.py | 18 - dreambooth.py | 25 +- textual_dreambooth.py | 948 ++++++++++++++++++++++++++++++++++++++++++++++ textual_inversion.py | 13 +- 5 files changed, 968 insertions(+), 47 deletions(-) delete mode 100644 data/dreambooth/prompt.py create mode 100644 textual_dreambooth.py diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 9075979..abd329d 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -15,6 +15,7 @@ class CSVDataModule(pl.LightningDataModule): tokenizer, instance_identifier, class_identifier=None, + class_subdir="db_cls", size=512, repeats=100, interpolation="bicubic", @@ -30,7 +31,7 @@ class CSVDataModule(pl.LightningDataModule): raise ValueError("data_file must be a file") self.data_root = self.data_file.parent - self.class_root = self.data_root.joinpath("db_cls") + self.class_root = self.data_root.joinpath(class_subdir) self.class_root.mkdir(parents=True, exist_ok=True) self.tokenizer = tokenizer @@ -140,11 +141,9 @@ class CSVDataset(Dataset): if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") - instance_prompt = prompt.format(self.instance_identifier) - example["instance_images"] = instance_image example["instance_prompt_ids"] = self.tokenizer( - instance_prompt, + prompt.format(self.instance_identifier), padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, @@ -155,11 +154,9 @@ class CSVDataset(Dataset): if not class_image.mode == "RGB": class_image = class_image.convert("RGB") - class_prompt = prompt.format(self.class_identifier) - example["class_images"] = class_image example["class_prompt_ids"] = self.tokenizer( - class_prompt, + prompt.format(self.class_identifier), padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py deleted file mode 100644 index b3a83ce..0000000 --- a/data/dreambooth/prompt.py +++ /dev/null @@ -1,18 +0,0 @@ -from torch.utils.data import Dataset - - -class PromptDataset(Dataset): - def __init__(self, prompt, nprompt, num_samples): - self.prompt = prompt - self.nprompt = nprompt - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, index): - example = {} - example["prompt"] = self.prompt - example["nprompt"] = self.nprompt - example["index"] = index - return example diff --git a/dreambooth.py b/dreambooth.py index aedf25c..0c5c42a 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -24,7 +24,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion import json from data.dreambooth.csv import CSVDataModule -from data.dreambooth.prompt import PromptDataset logger = get_logger(__name__) @@ -122,7 +121,7 @@ def parse_args(): parser.add_argument( "--learning_rate", type=float, - default=3e-6, + default=1e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -219,16 +218,9 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=40, + default=30, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) - parser.add_argument( - "--class_data_dir", - type=str, - default=None, - required=False, - help="A folder containing the training data of class images.", - ) parser.add_argument( "--prior_loss_weight", type=float, @@ -311,7 +303,7 @@ class Checkpointer: self.output_dir = output_dir self.instance_identifier = instance_identifier self.sample_image_size = sample_image_size - self.seed = seed + self.seed = seed or torch.random.seed() self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size @@ -406,6 +398,8 @@ class Checkpointer: del unwrapped del scheduler del pipeline + del generator + del stable_latents if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -523,11 +517,13 @@ def main(): tokenizer=tokenizer, instance_identifier=args.instance_identifier, class_identifier=args.class_identifier, + class_subdir="db_cls", size=args.resolution, repeats=args.repeats, center_crop=args.center_crop, valid_set_size=args.sample_batch_size*args.sample_batches, - collate_fn=collate_fn) + collate_fn=collate_fn + ) datamodule.prepare_data() datamodule.setup() @@ -587,7 +583,7 @@ def main(): sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, sample_batches=args.sample_batches, - seed=args.seed or torch.random.seed() + seed=args.seed ) # Scheduler and math around the number of training steps. @@ -699,8 +695,7 @@ def main(): loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior, noise_prior, - reduction="none").mean([1, 2, 3]).mean() + prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss diff --git a/textual_dreambooth.py b/textual_dreambooth.py new file mode 100644 index 0000000..a46953d --- /dev/null +++ b/textual_dreambooth.py @@ -0,0 +1,948 @@ +import argparse +import itertools +import math +import os +import datetime +import logging +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import LoggerType, set_seed +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel +from schedulers.scheduling_euler_a import EulerAScheduler +from diffusers.optimization import get_scheduler +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from slugify import slugify +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +import json +import os + +from data.dreambooth.csv import CSVDataModule + +logger = get_logger(__name__) + + +torch.backends.cuda.matmul.allow_tf32 = True + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Simple example of a training script." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--train_data_file", + type=str, + default=None, + help="A CSV file containing the training data." + ) + parser.add_argument( + "--placeholder_token", + type=str, + default=None, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--initializer_token", + type=str, + default=None, + help="A token to use as initializer word." + ) + parser.add_argument( + "--num_vec_per_token", + type=int, + default=1, + help=( + "The number of vectors used to represent the placeholder token. The higher the number, the better the" + " result at the cost of editability. This can be fixed by prompt editing." + ), + ) + parser.add_argument( + "--initialize_rest_random", + action="store_true", + help="Initialize rest of the placeholder tokens with random." + ) + parser.add_argument( + "--use_class_images", + action="store_true", + default=True, + help="Include class images in the loss calculation a la Dreambooth.", + ) + parser.add_argument( + "--repeats", + type=int, + default=100, + help="How many times to repeat the training data.") + parser.add_argument( + "--output_dir", + type=str, + default="output/text-inversion", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=5000, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=True, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer." + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer." + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=1e-2, + help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer" + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank" + ) + parser.add_argument( + "--checkpoint_frequency", + type=int, + default=500, + help="How often to save a checkpoint and sample image", + ) + parser.add_argument( + "--sample_image_size", + type=int, + default=512, + help="Size of sample images", + ) + parser.add_argument( + "--sample_batches", + type=int, + default=1, + help="Number of sample batches to generate per checkpoint", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=1, + help="Number of samples to generate per batch", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_steps", + type=int, + default=30, + help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss." + ) + parser.add_argument( + "--resume_from", + type=str, + default=None, + help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" + ) + parser.add_argument( + "--resume_checkpoint", + type=str, + default=None, + help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)." + ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this." + ) + + args = parser.parse_args() + if args.resume_from is not None: + with open(f"{args.resume_from}/resume.json", 'rt') as f: + args = parser.parse_args( + namespace=argparse.Namespace(**json.load(f)["args"])) + elif args.config is not None: + with open(args.config, 'rt') as f: + args = parser.parse_args( + namespace=argparse.Namespace(**json.load(f)["args"])) + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.train_data_file is None: + raise ValueError("You must specify --train_data_file") + + if args.pretrained_model_name_or_path is None: + raise ValueError("You must specify --pretrained_model_name_or_path") + + if args.placeholder_token is None: + raise ValueError("You must specify --placeholder_token") + + if args.initializer_token is None: + raise ValueError("You must specify --initializer_token") + + if args.output_dir is None: + raise ValueError("You must specify --output_dir") + + return args + + +def freeze_params(params): + for param in params: + param.requires_grad = False + + +def save_resume_file(basepath, args, extra={}): + info = {"args": vars(args)} + info["args"].update(extra) + with open(f"{basepath}/resume.json", "w") as f: + json.dump(info, f, indent=4) + + +def make_grid(images, rows, cols): + w, h = images[0].size + grid = Image.new('RGB', size=(cols*w, rows*h)) + for i, image in enumerate(images): + grid.paste(image, box=(i % cols*w, i//cols*h)) + return grid + + +def add_tokens_and_get_placeholder_token(args, token_ids, tokenizer, text_encoder): + assert args.num_vec_per_token >= len(token_ids) + placeholder_tokens = [f"{args.placeholder_token}_{i}" for i in range(args.num_vec_per_token)] + + for placeholder_token in placeholder_tokens: + num_added_tokens = tokenizer.add_tokens(placeholder_token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + placeholder_token = " ".join(placeholder_tokens) + placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False) + + print(f"The placeholder tokens are {placeholder_token} while the ids are {placeholder_token_ids}") + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + + if args.initialize_rest_random: + # The idea is that the placeholder tokens form adjectives as in x x x white dog. + for i, placeholder_token_id in enumerate(placeholder_token_ids): + if len(placeholder_token_ids) - i < len(token_ids): + token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]] + else: + token_embeds[placeholder_token_id] = torch.rand_like(token_embeds[placeholder_token_id]) + else: + for i, placeholder_token_id in enumerate(placeholder_token_ids): + token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]] + + return placeholder_token, placeholder_token_ids + + +class Checkpointer: + def __init__( + self, + datamodule, + accelerator, + vae, + unet, + tokenizer, + placeholder_token, + placeholder_token_ids, + output_dir, + sample_image_size, + sample_batches, + sample_batch_size, + seed + ): + self.datamodule = datamodule + self.accelerator = accelerator + self.vae = vae + self.unet = unet + self.tokenizer = tokenizer + self.placeholder_token = placeholder_token + self.placeholder_token_ids = placeholder_token_ids + self.output_dir = output_dir + self.sample_image_size = sample_image_size + self.seed = seed or torch.random.seed() + self.sample_batches = sample_batches + self.sample_batch_size = sample_batch_size + + @torch.no_grad() + def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): + print("Saving checkpoint for step %d..." % step) + + if path is None: + checkpoints_path = f"{self.output_dir}/checkpoints" + os.makedirs(checkpoints_path, exist_ok=True) + + unwrapped = self.accelerator.unwrap_model(text_encoder) + + # Save a checkpoint + learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_ids] + learned_embeds_dict = {} + for i, placeholder_token in enumerate(self.placeholder_token.split(" ")): + learned_embeds_dict[placeholder_token] = learned_embeds[i].detach().cpu() + + filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) + if path is not None: + torch.save(learned_embeds_dict, path) + else: + torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") + torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") + + del unwrapped + del learned_embeds + + @torch.no_grad() + def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): + samples_path = Path(self.output_dir).joinpath("samples") + + unwrapped = self.accelerator.unwrap_model(text_encoder) + scheduler = EulerAScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + + # Save a sample image + pipeline = VlpnStableDiffusion( + text_encoder=unwrapped, + vae=self.vae, + unet=self.unet, + tokenizer=self.tokenizer, + scheduler=scheduler, + ).to(self.accelerator.device) + pipeline.enable_attention_slicing() + + train_data = self.datamodule.train_dataloader() + val_data = self.datamodule.val_dataloader() + + generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) + stable_latents = torch.randn( + (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), + device=pipeline.device, + generator=generator, + ) + + for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: + all_samples = [] + file_path = samples_path.joinpath(pool, f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) + + data_enum = enumerate(data) + + for i in range(self.sample_batches): + batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] + prompt = [prompt.format(self.placeholder_token) + for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] + nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] + + with self.accelerator.autocast(): + samples = pipeline( + prompt=prompt, + negative_prompt=nprompt, + height=self.sample_image_size, + width=self.sample_image_size, + latents=latents[:len(prompt)] if latents is not None else None, + generator=generator if latents is not None else None, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + output_type='pil' + )["sample"] + + all_samples += samples + + del samples + + image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) + image_grid.save(file_path) + + del all_samples + del image_grid + + del unwrapped + del scheduler + del pipeline + del generator + del stable_latents + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def main(): + args = parse_args() + + global_step_offset = 0 + if args.resume_from is not None: + basepath = Path(args.resume_from) + print("Resuming state from %s" % args.resume_from) + with open(basepath.joinpath("resume.json"), 'r') as f: + state = json.load(f) + global_step_offset = state["args"].get("global_step", 0) + + print("We've trained %d steps so far" % global_step_offset) + else: + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now) + basepath.mkdir(parents=True, exist_ok=True) + + accelerator = Accelerator( + log_with=LoggerType.TENSORBOARD, + logging_dir=f"{basepath}", + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision + ) + + logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Load the tokenizer and add the placeholder token as a additional special token + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path + '/tokenizer' + ) + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path + '/text_encoder', + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path + '/vae', + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path + '/unet', + ) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + slice_size = unet.config.attention_head_dim // 2 + unet.set_attention_slice(slice_size) + + token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + # regardless of whether the number of token_ids is 1 or more, it'll set one and then keep repeating. + placeholder_token, placeholder_token_ids = add_tokens_and_get_placeholder_token( + args, token_ids, tokenizer, text_encoder) + + # if args.resume_checkpoint is not None: + # token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ + # args.placeholder_token] + # else: + # token_embeds[placeholder_token_id] = initializer_token_embeddings + + # Freeze vae and unet + freeze_params(vae.parameters()) + freeze_params(unet.parameters()) + # Freeze all parameters except for the token embeddings in text encoder + params_to_freeze = itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + text_encoder.text_model.embeddings.position_embedding.parameters(), + ) + freeze_params(params_to_freeze) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * + args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Initialize the optimizer + optimizer = optimizer_class( + text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000 + ) + + def collate_fn(examples): + prompts = [example["prompts"] for example in examples] + nprompts = [example["nprompts"] for example in examples] + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # concat class and instance examples for prior preservation + if args.use_class_images and "class_prompt_ids" in examples[0]: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) + + input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + + batch = { + "prompts": prompts, + "nprompts": nprompts, + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + datamodule = CSVDataModule( + data_file=args.train_data_file, + batch_size=args.train_batch_size, + tokenizer=tokenizer, + instance_identifier=placeholder_token, + class_identifier=args.initializer_token if args.use_class_images else None, + class_subdir="ti_cls", + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + valid_set_size=args.sample_batch_size*args.sample_batches, + collate_fn=collate_fn + ) + + datamodule.prepare_data() + datamodule.setup() + + if args.use_class_images: + missing_data = [item for item in datamodule.data if not item[1].exists()] + + if len(missing_data) != 0: + batched_data = [missing_data[i:i+args.sample_batch_size] + for i in range(0, len(missing_data), args.sample_batch_size)] + + scheduler = EulerAScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + ).to(accelerator.device) + pipeline.enable_attention_slicing() + + for batch in batched_data: + image_name = [p[1] for p in batch] + prompt = [p[2].format(args.initializer_token) for p in batch] + nprompt = [p[3] for p in batch] + + with accelerator.autocast(): + images = pipeline( + prompt=prompt, + negative_prompt=nprompt, + num_inference_steps=args.sample_steps + ).images + + for i, image in enumerate(images): + image.save(image_name[i]) + + del pipeline + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + train_dataloader = datamodule.train_dataloader() + val_dataloader = datamodule.val_dataloader() + + checkpointer = Checkpointer( + datamodule=datamodule, + accelerator=accelerator, + vae=vae, + unet=unet, + tokenizer=tokenizer, + placeholder_token=args.placeholder_token, + placeholder_token_ids=placeholder_token_ids, + output_dir=basepath, + sample_image_size=args.sample_image_size, + sample_batch_size=args.sample_batch_size, + sample_batches=args.sample_batches, + seed=args.seed + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) + + # Move vae and unet to device + vae.to(accelerator.device) + unet.to(accelerator.device) + + # Keep vae and unet in eval mode as we don't train these + vae.eval() + unet.eval() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + num_val_steps_per_epoch = len(val_dataloader) + num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + val_steps = num_val_steps_per_epoch * num_epochs + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + + global_step = 0 + min_val_loss = np.inf + + if accelerator.is_main_process: + checkpointer.save_samples( + 0, + text_encoder, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) + + local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), + disable=not accelerator.is_local_main_process) + local_progress_bar.set_description("Batch X out of Y") + + global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process) + global_progress_bar.set_description("Total progress") + + try: + for epoch in range(num_epochs): + local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}") + local_progress_bar.reset() + + text_encoder.train() + train_loss = 0.0 + + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + # Convert images to latent space + with torch.no_grad(): + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn(latents.shape).to(latents.device) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.use_class_images: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + noise, noise_prior = torch.chunk(noise, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + + accelerator.backward(loss) + + # Zero out the gradients for all token embeddings except the newly added + # embeddings for the concept, as we only want to optimize the concept embeddings + if accelerator.num_processes > 1: + grads = text_encoder.module.get_input_embeddings().weight.grad + else: + grads = text_encoder.get_input_embeddings().weight.grad + # Get the index for tokens that we want to zero the grads for + grad_mask = torch.arange(len(tokenizer)) != placeholder_token_ids[0] + for i in range(1, len(placeholder_token_ids)): + grad_mask = grad_mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i]) + grads.data[grad_mask, :] = grads.data[grad_mask, :].fill_(0) + + optimizer.step() + if not accelerator.optimizer_step_was_skipped: + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + loss = loss.detach().item() + train_loss += loss + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + local_progress_bar.update(1) + global_progress_bar.update(1) + + global_step += 1 + + if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: + local_progress_bar.clear() + global_progress_bar.clear() + + checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder) + save_resume_file(basepath, args, { + "global_step": global_step + global_step_offset, + "resume_checkpoint": f"{basepath}/checkpoints/last.bin" + }) + + logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} + local_progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + train_loss /= len(train_dataloader) + + accelerator.wait_for_everyone() + + text_encoder.eval() + val_loss = 0.0 + + for step, batch in enumerate(val_dataloader): + with torch.no_grad(): + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 + + noise = torch.randn(latents.shape).to(latents.device) + bsz = latents.shape[0] + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, + (bsz,), device=latents.device) + timesteps = timesteps.long() + + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) + + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + + loss = loss.detach().item() + val_loss += loss + + if accelerator.sync_gradients: + local_progress_bar.update(1) + global_progress_bar.update(1) + + logs = {"mode": "validation", "loss": loss} + local_progress_bar.set_postfix(**logs) + + val_loss /= len(val_dataloader) + + accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) + + local_progress_bar.clear() + global_progress_bar.clear() + + if min_val_loss > val_loss: + accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") + checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) + min_val_loss = val_loss + + if accelerator.is_main_process: + checkpointer.save_samples( + global_step + global_step_offset, + text_encoder, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) + + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + print("Finished! Saving final checkpoint and resume state.") + checkpointer.checkpoint( + global_step + global_step_offset, + "end", + text_encoder, + path=f"{basepath}/learned_embeds.bin" + ) + + save_resume_file(basepath, args, { + "global_step": global_step + global_step_offset, + "resume_checkpoint": f"{basepath}/checkpoints/last.bin" + }) + + accelerator.end_training() + + except KeyboardInterrupt: + if accelerator.is_main_process: + print("Interrupted, saving checkpoint and resume state...") + checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder) + save_resume_file(basepath, args, { + "global_step": global_step + global_step_offset, + "resume_checkpoint": f"{basepath}/checkpoints/last.bin" + }) + accelerator.end_training() + quit() + + +if __name__ == "__main__": + main() diff --git a/textual_inversion.py b/textual_inversion.py index d842288..7919ebd 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -230,7 +230,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=40, + default=30, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -329,7 +329,7 @@ class Checkpointer: self.placeholder_token_id = placeholder_token_id self.output_dir = output_dir self.sample_image_size = sample_image_size - self.seed = seed + self.seed = seed or torch.random.seed() self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size @@ -481,9 +481,9 @@ def main(): # Convert the initializer_token, placeholder_token to ids initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) # Check if initializer_token is a single token or a sequence of tokens - if args.vectors_per_token % len(initializer_token_ids) != 0: + if len(initializer_token_ids) > 1: raise ValueError( - f"vectors_per_token ({args.vectors_per_token}) must be divisible by initializer token ({len(initializer_token_ids)}).") + f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.") initializer_token_ids = torch.tensor(initializer_token_ids) placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) @@ -590,7 +590,7 @@ def main(): sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, sample_batches=args.sample_batches, - seed=args.seed or torch.random.seed() + seed=args.seed ) # Scheduler and math around the number of training steps. @@ -620,8 +620,7 @@ def main(): unet.eval() # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch -- cgit v1.2.3-70-g09d2