From 9b808b6ca102cfec0c273626a0bcadf897b7c942 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 19 Dec 2022 21:10:58 +0100
Subject: Improved dataset prompt handling, fixed

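Dataset prompts are now handled as keyword lists: prompt_to_keywords() splits a
comma-separated prompt and expands each keyword via an optional "expansions" map
in the metadata JSON, and keywords_to_prompt() applies tag dropout and shuffling
per keyword when a sample is drawn. The training scripts are also renamed to
train_dreambooth.py and train_ti.py.

For illustration only (the keywords and expansion values below are hypothetical;
the functions are the ones added in data/csv.py):

    expansions = {"red fox": "animal, mammal"}       # hypothetical "expansions" entry from metadata
    keywords = prompt_to_keywords("red fox, forest", expansions)
    # -> ["red fox", "animal", "mammal", "forest"]
    text = keywords_to_prompt(keywords, dropout=0.1)
    # -> e.g. "forest, mammal, red fox" (shuffled; each keyword dropped with probability 0.1)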
---
 .gitignore           |    5 +-
 data/csv.py          |   41 +-
 dreambooth.py        | 1133 --------------------------------------------------
 textual_inversion.py | 1034 ---------------------------------------------
 train_dreambooth.py  | 1133 ++++++++++++++++++++++++++++++++++++++++++++++++++
 train_ti.py          | 1032 +++++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 2190 insertions(+), 2188 deletions(-)
 delete mode 100644 dreambooth.py
 delete mode 100644 textual_inversion.py
 create mode 100644 train_dreambooth.py
 create mode 100644 train_ti.py

diff --git a/.gitignore b/.gitignore
index d84b4dd..fba4926 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,8 +160,7 @@ cython_debug/
 #.idea/
 
 output/
-conf/
-embeddings_ti/
-embeddings_ag/
+conf*/
+embeddings*/
 v1-inference.yaml*
 *.old
diff --git a/data/csv.py b/data/csv.py
index 053457b..6525e45 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -16,26 +16,29 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-def shuffle_prompt(prompt: str, dropout: float = 0):
-    def handle_block(block: str):
-        words = block.split(", ")
-        words = [w for w in words if w != ""]
-        if dropout != 0:
-            words = [w for w in words if np.random.random() > dropout]
-        np.random.shuffle(words)
-        return ", ".join(words)
-
-    prompt = prompt.split(". ")
-    prompt = [handle_block(b) for b in prompt if b != ""]
+def keywords_to_prompt(prompt: list[str], dropout: float = 0) -> str:
+    if dropout != 0:
+        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
     np.random.shuffle(prompt)
-    prompt = ". ".join(prompt)
-    return prompt
+    return ", ".join(prompt)
+
+
+def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]:
+    def expand_keyword(keyword: str) -> list[str]:
+        return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
+
+    return [
+        kw
+        for keyword in prompt.split(", ")
+        for kw in expand_keyword(keyword)
+        if keyword != ""
+    ]
 
 
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
-    prompt: str
+    prompt: list[str]
     nprompt: str
 
 
@@ -91,7 +94,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.num_workers = num_workers
         self.batch_size = batch_size
 
-    def prepare_items(self, template, data) -> list[CSVDataItem]:
+    def prepare_items(self, template, expansions, data) -> list[CSVDataItem]:
         image = template["image"] if "image" in template else "{}"
         prompt = template["prompt"] if "prompt" in template else "{content}"
         nprompt = template["nprompt"] if "nprompt" in template else "{content}"
@@ -100,7 +103,8 @@ class CSVDataModule(pl.LightningDataModule):
             CSVDataItem(
                 self.data_root.joinpath(image.format(item["image"])),
                 None,
-                prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
+                prompt_to_keywords(prompt.format(
+                    **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions),
                 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
             )
             for item in data
@@ -130,6 +134,7 @@ class CSVDataModule(pl.LightningDataModule):
         with open(self.data_file, 'rt') as f:
             metadata = json.load(f)
         template = metadata[self.template_key] if self.template_key in metadata else {}
+        expansions = metadata["expansions"] if "expansions" in metadata else {}
         items = metadata["items"] if "items" in metadata else []
 
         if self.mode is not None:
@@ -138,7 +143,7 @@ class CSVDataModule(pl.LightningDataModule):
                 for item in items
                 if "mode" in item and self.mode in item["mode"]
             ]
-        items = self.prepare_items(template, items)
+        items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
 
         num_images = len(items)
@@ -255,7 +260,7 @@ class CSVDataset(Dataset):
 
         example = {}
 
-        example["prompts"] = shuffle_prompt(unprocessed_example["prompts"])
+        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout)
         example["nprompts"] = unprocessed_example["nprompts"]
 
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
diff --git a/dreambooth.py b/dreambooth.py
deleted file mode 100644
index 3eecf9c..0000000
--- a/dreambooth.py
+++ /dev/null
@@ -1,1133 +0,0 @@
-import argparse
-import itertools
-import math
-import datetime
-import logging
-import json
-from pathlib import Path
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
-from diffusers.training_utils import EMAModel
-from PIL import Image
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
-from slugify import slugify
-
-from common import load_text_embeddings
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from pipelines.util import set_use_memory_efficient_attention_xformers
-from data.csv import CSVDataModule
-from training.optimization import get_one_cycle_schedule
-from models.clip.prompt import PromptProcessor
-
-logger = get_logger(__name__)
-
-
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.benchmark = True
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        type=str,
-        default=None,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-    )
-    parser.add_argument(
-        "--tokenizer_name",
-        type=str,
-        default=None,
-        help="Pretrained tokenizer name or path if not the same as model_name",
-    )
-    parser.add_argument(
-        "--train_data_file",
-        type=str,
-        default=None,
-        help="A folder containing the training data."
-    )
-    parser.add_argument(
-        "--train_data_template",
-        type=str,
-        default="template",
-    )
-    parser.add_argument(
-        "--instance_identifier",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--class_identifier",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--placeholder_token",
-        type=str,
-        nargs='*',
-        default=[],
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--initializer_token",
-        type=str,
-        nargs='*',
-        default=[],
-        help="A token to use as initializer word."
-    )
-    parser.add_argument(
-        "--train_text_encoder",
-        action="store_true",
-        default=True,
-        help="Whether to train the whole text encoder."
-    )
-    parser.add_argument(
-        "--train_text_encoder_epochs",
-        default=999999,
-        help="Number of epochs the text encoder will be trained."
-    )
-    parser.add_argument(
-        "--tag_dropout",
-        type=float,
-        default=0.1,
-        help="Tag dropout probability.",
-    )
-    parser.add_argument(
-        "--num_class_images",
-        type=int,
-        default=400,
-        help="How many class images to generate."
-    )
-    parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output/dreambooth",
-        help="The output directory where the model predictions and checkpoints will be written.",
-    )
-    parser.add_argument(
-        "--embeddings_dir",
-        type=str,
-        default=None,
-        help="The embeddings directory where Textual Inversion embeddings are stored.",
-    )
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default=None,
-        help="A mode to filter the dataset.",
-    )
-    parser.add_argument(
-        "--seed",
-        type=int,
-        default=None,
-        help="A seed for reproducible training."
-    )
-    parser.add_argument(
-        "--resolution",
-        type=int,
-        default=768,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
-    )
-    parser.add_argument(
-        "--center_crop",
-        action="store_true",
-        help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
-        "--num_train_epochs",
-        type=int,
-        default=100
-    )
-    parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=None,
-        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
-    )
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
-    parser.add_argument(
-        "--gradient_checkpointing",
-        action="store_true",
-        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
-    )
-    parser.add_argument(
-        "--learning_rate_unet",
-        type=float,
-        default=2e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
-        "--learning_rate_text",
-        type=float,
-        default=2e-6,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
-        "--scale_lr",
-        action="store_true",
-        default=True,
-        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
-    )
-    parser.add_argument(
-        "--lr_scheduler",
-        type=str,
-        default="one_cycle",
-        help=(
-            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup", "one_cycle"]'
-        ),
-    )
-    parser.add_argument(
-        "--lr_warmup_epochs",
-        type=int,
-        default=10,
-        help="Number of steps for the warmup in the lr scheduler."
-    )
-    parser.add_argument(
-        "--lr_cycles",
-        type=int,
-        default=None,
-        help="Number of restart cycles in the lr scheduler (if supported)."
-    )
-    parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        default=True,
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=6/7
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        default=True,
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
-    )
-    parser.add_argument(
-        "--adam_beta1",
-        type=float,
-        default=0.9,
-        help="The beta1 parameter for the Adam optimizer."
-    )
-    parser.add_argument(
-        "--adam_beta2",
-        type=float,
-        default=0.999,
-        help="The beta2 parameter for the Adam optimizer."
-    )
-    parser.add_argument(
-        "--adam_weight_decay",
-        type=float,
-        default=1e-2,
-        help="Weight decay to use."
-    )
-    parser.add_argument(
-        "--adam_epsilon",
-        type=float,
-        default=1e-08,
-        help="Epsilon value for the Adam optimizer"
-    )
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="no",
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose"
-            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."
-        ),
-    )
-    parser.add_argument(
-        "--sample_frequency",
-        type=int,
-        default=1,
-        help="How often to save a checkpoint and sample image",
-    )
-    parser.add_argument(
-        "--sample_image_size",
-        type=int,
-        default=768,
-        help="Size of sample images",
-    )
-    parser.add_argument(
-        "--sample_batches",
-        type=int,
-        default=1,
-        help="Number of sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--sample_batch_size",
-        type=int,
-        default=1,
-        help="Number of samples to generate per batch",
-    )
-    parser.add_argument(
-        "--valid_set_size",
-        type=int,
-        default=None,
-        help="Number of images in the validation dataset."
-    )
-    parser.add_argument(
-        "--train_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
-        "--sample_steps",
-        type=int,
-        default=15,
-        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
-    )
-    parser.add_argument(
-        "--prior_loss_weight",
-        type=float,
-        default=1.0,
-        help="The weight of prior preservation loss."
-    )
-    parser.add_argument(
-        "--max_grad_norm",
-        default=1.0,
-        type=float,
-        help="Max gradient norm."
-    )
-    parser.add_argument(
-        "--noise_timesteps",
-        type=int,
-        default=1000,
-    )
-    parser.add_argument(
-        "--config",
-        type=str,
-        default=None,
-        help="Path to a JSON configuration file containing arguments for invoking this script."
-    )
-
-    args = parser.parse_args()
-    if args.config is not None:
-        with open(args.config, 'rt') as f:
-            args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)["args"]))
-
-    if args.train_data_file is None:
-        raise ValueError("You must specify --train_data_file")
-
-    if args.pretrained_model_name_or_path is None:
-        raise ValueError("You must specify --pretrained_model_name_or_path")
-
-    if args.instance_identifier is None:
-        raise ValueError("You must specify --instance_identifier")
-
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
-    if isinstance(args.placeholder_token, str):
-        args.placeholder_token = [args.placeholder_token]
-
-    if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
-
-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
-
-    if args.output_dir is None:
-        raise ValueError("You must specify --output_dir")
-
-    return args
-
-
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.sum = self.count = self.avg = 0
-
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-class Checkpointer:
-    def __init__(
-        self,
-        datamodule,
-        accelerator,
-        vae,
-        unet,
-        ema_unet,
-        tokenizer,
-        text_encoder,
-        scheduler,
-        output_dir: Path,
-        instance_identifier,
-        placeholder_token,
-        placeholder_token_id,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed
-    ):
-        self.datamodule = datamodule
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.ema_unet = ema_unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.scheduler = scheduler
-        self.output_dir = output_dir
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
-
-    @torch.no_grad()
-    def save_model(self):
-        print("Saving model...")
-
-        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        )
-        pipeline.save_pretrained(self.output_dir.joinpath("model"))
-
-        del unet
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-    @torch.no_grad()
-    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        ).to(self.accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
-
-        del unet
-        del text_encoder
-        del pipeline
-        del generator
-        del stable_latents
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-
-def main():
-    args = parse_args()
-
-    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
-
-    instance_identifier = args.instance_identifier
-
-    if len(args.placeholder_token) != 0:
-        instance_identifier = instance_identifier.format(args.placeholder_token[0])
-
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
-    basepath.mkdir(parents=True, exist_ok=True)
-
-    accelerator = Accelerator(
-        log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision
-    )
-
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
-
-    args.seed = args.seed or (torch.random.seed() >> 32)
-    set_seed(args.seed)
-
-    save_args(basepath, args)
-
-    # Load the tokenizer and add the placeholder token as a additional special token
-    if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
-    elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
-
-    # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
-    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    set_use_memory_efficient_attention_xformers(unet, True)
-    set_use_memory_efficient_attention_xformers(vae, True)
-
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
-    ema_unet = None
-    if args.use_ema:
-        ema_unet = EMAModel(
-            unet,
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-            device=accelerator.device
-        )
-
-    # Freeze text_encoder and vae
-    vae.requires_grad_(False)
-
-    if args.embeddings_dir is not None:
-        embeddings_dir = Path(args.embeddings_dir)
-        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-            raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
-
-    if len(args.placeholder_token) != 0:
-        # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = torch.stack([
-            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
-            for token in args.initializer_token
-        ])
-
-        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        print(f"Added {num_added_tokens} new tokens.")
-
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
-
-        # Resize the token embeddings as we are adding new special tokens to the tokenizer
-        text_encoder.resize_token_embeddings(len(tokenizer))
-
-        token_embeds = text_encoder.get_input_embeddings().weight.data
-        original_token_embeds = token_embeds.clone().to(accelerator.device)
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
-    else:
-        placeholder_token_id = []
-
-    if args.train_text_encoder:
-        print(f"Training entire text encoder.")
-    else:
-        print(f"Training added text embeddings")
-
-        freeze_params(itertools.chain(
-            text_encoder.text_model.encoder.parameters(),
-            text_encoder.text_model.final_layer_norm.parameters(),
-            text_encoder.text_model.embeddings.position_embedding.parameters(),
-        ))
-
-        index_fixed_tokens = torch.arange(len(tokenizer))
-        index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
-
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
-    if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-        args.learning_rate_text = (
-            args.learning_rate_text * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
-    if args.use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
-
-        optimizer_class = bnb.optim.AdamW8bit
-    else:
-        optimizer_class = torch.optim.AdamW
-
-    if args.train_text_encoder:
-        text_encoder_params_to_optimize = text_encoder.parameters()
-    else:
-        text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters()
-
-    # Initialize the optimizer
-    optimizer = optimizer_class(
-        [
-            {
-                'params': unet.parameters(),
-                'lr': args.learning_rate_unet,
-            },
-            {
-                'params': text_encoder_params_to_optimize,
-                'lr': args.learning_rate_text,
-            }
-        ],
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-    )
-
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompts": prompts,
-            "nprompts": nprompts,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-        return batch
-
-    datamodule = CSVDataModule(
-        data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
-        prompt_processor=prompt_processor,
-        instance_identifier=instance_identifier,
-        class_identifier=args.class_identifier,
-        class_subdir="cls",
-        num_class_images=args.num_class_images,
-        size=args.resolution,
-        repeats=args.repeats,
-        mode=args.mode,
-        dropout=args.tag_dropout,
-        center_crop=args.center_crop,
-        template_key=args.train_data_template,
-        valid_set_size=args.valid_set_size,
-        num_workers=args.dataloader_num_workers,
-        collate_fn=collate_fn
-    )
-
-    datamodule.prepare_data()
-    datamodule.setup()
-
-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.autocast("cuda"), torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-    train_dataloader = datamodule.train_dataloader()
-    val_dataloader = datamodule.val_dataloader()
-
-    # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
-    if args.lr_scheduler == "one_cycle":
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-        )
-
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    # Move text_encoder and vae to device
-    vae.to(accelerator.device, dtype=weight_dtype)
-
-    # Keep text_encoder and vae in eval mode as we don't train these
-    vae.eval()
-
-    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-
-    num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-    val_steps = num_val_steps_per_epoch * num_epochs
-
-    # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
-    if accelerator.is_main_process:
-        config = vars(args).copy()
-        config["initializer_token"] = " ".join(config["initializer_token"])
-        config["placeholder_token"] = " ".join(config["placeholder_token"])
-        accelerator.init_trackers("dreambooth", config=config)
-
-    # Train!
-    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-
-    logger.info("***** Running training *****")
-    logger.info(f"  Num Epochs = {num_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {args.max_train_steps}")
-    # Only show the progress bar once on each machine.
-
-    global_step = 0
-
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    max_acc_val = 0.0
-
-    checkpointer = Checkpointer(
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        ema_unet=ema_unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        scheduler=checkpoint_scheduler,
-        output_dir=basepath,
-        instance_identifier=instance_identifier,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    if accelerator.is_main_process:
-        checkpointer.save_samples(0, args.sample_steps)
-
-    local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    local_progress_bar.set_description("Epoch X / Y")
-
-    global_progress_bar = tqdm(
-        range(args.max_train_steps + val_steps),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    global_progress_bar.set_description("Total progress")
-
-    try:
-        for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-            local_progress_bar.reset()
-
-            unet.train()
-
-            if epoch < args.train_text_encoder_epochs:
-                text_encoder.train()
-            elif epoch == args.train_text_encoder_epochs:
-                freeze_params(text_encoder.parameters())
-
-            sample_checkpoint = False
-
-            for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(unet):
-                    # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-                    # Predict the noise residual
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    if args.num_class_images != 0:
-                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-                        target, target_prior = torch.chunk(target, 2, dim=0)
-
-                        # Compute instance loss
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
-
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    acc = (model_pred == latents).float().mean()
-
-                    accelerator.backward(loss)
-
-                    if accelerator.sync_gradients:
-                        params_to_clip = (
-                            itertools.chain(unet.parameters(), text_encoder.parameters())
-                            if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                            else unet.parameters()
-                        )
-                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                    optimizer.step()
-                    if not accelerator.optimizer_step_was_skipped:
-                        lr_scheduler.step()
-                    if args.use_ema:
-                        ema_unet.step(unet)
-                    optimizer.zero_grad(set_to_none=True)
-
-                    if not args.train_text_encoder:
-                        # Let's make sure we don't update any embedding weights besides the newly added token
-                        with torch.no_grad():
-                            text_encoder.get_input_embeddings(
-                            ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
-
-                    avg_loss.update(loss.detach_(), bsz)
-                    avg_acc.update(acc.detach_(), bsz)
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
-
-                    global_step += 1
-
-                logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
-                    "train/cur_loss": loss.item(),
-                    "train/cur_acc": acc.item(),
-                    "lr/unet": lr_scheduler.get_last_lr()[0],
-                    "lr/text": lr_scheduler.get_last_lr()[1]
-                }
-                if args.use_ema:
-                    logs["ema_decay"] = 1 - ema_unet.decay
-
-                accelerator.log(logs, step=global_step)
-
-                local_progress_bar.set_postfix(**logs)
-
-                if global_step >= args.max_train_steps:
-                    break
-
-            accelerator.wait_for_everyone()
-
-            unet.eval()
-            text_encoder.eval()
-
-            with torch.inference_mode():
-                for step, batch in enumerate(val_dataloader):
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    acc = (model_pred == latents).float().mean()
-
-                    avg_loss_val.update(loss.detach_(), bsz)
-                    avg_acc_val.update(acc.detach_(), bsz)
-
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                    logs = {
-                        "val/loss": avg_loss_val.avg.item(),
-                        "val/acc": avg_acc_val.avg.item(),
-                        "val/cur_loss": loss.item(),
-                        "val/cur_acc": acc.item(),
-                    }
-                    local_progress_bar.set_postfix(**logs)
-
-            accelerator.log({
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-            }, step=global_step)
-
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
-            if avg_acc_val.avg.item() > max_acc_val:
-                accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                max_acc_val = avg_acc_val.avg.item()
-
-            if accelerator.is_main_process:
-                if (epoch + 1) % args.sample_frequency == 0:
-                    checkpointer.save_samples(global_step, args.sample_steps)
-
-        # Create the pipeline using using the trained modules and save it.
-        if accelerator.is_main_process:
-            print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.save_model()
-
-            accelerator.end_training()
-
-    except KeyboardInterrupt:
-        if accelerator.is_main_process:
-            print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.save_model()
-            accelerator.end_training()
-        quit()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/textual_inversion.py b/textual_inversion.py
deleted file mode 100644
index e281c73..0000000
--- a/textual_inversion.py
+++ /dev/null
@@ -1,1034 +0,0 @@
-import argparse
-import itertools
-import math
-import os
-import datetime
-import logging
-import json
-from pathlib import Path
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-
-from accelerate import Accelerator
-from accelerate.logging import get_logger
-from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
-from PIL import Image
-from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
-from slugify import slugify
-
-from common import load_text_embeddings, load_text_embedding
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from pipelines.util import set_use_memory_efficient_attention_xformers
-from data.csv import CSVDataModule, CSVDataItem
-from training.optimization import get_one_cycle_schedule
-from models.clip.prompt import PromptProcessor
-
-logger = get_logger(__name__)
-
-
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.benchmark = True
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        type=str,
-        default=None,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-    )
-    parser.add_argument(
-        "--tokenizer_name",
-        type=str,
-        default=None,
-        help="Pretrained tokenizer name or path if not the same as model_name",
-    )
-    parser.add_argument(
-        "--train_data_file",
-        type=str,
-        default=None,
-        help="A CSV file containing the training data."
-    )
-    parser.add_argument(
-        "--train_data_template",
-        type=str,
-        default="template",
-    )
-    parser.add_argument(
-        "--instance_identifier",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--class_identifier",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--placeholder_token",
-        type=str,
-        nargs='*',
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--initializer_token",
-        type=str,
-        nargs='*',
-        help="A token to use as initializer word."
-    )
-    parser.add_argument(
-        "--num_class_images",
-        type=int,
-        default=400,
-        help="How many class images to generate."
-    )
-    parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output/text-inversion",
-        help="The output directory where the model predictions and checkpoints will be written.",
-    )
-    parser.add_argument(
-        "--embeddings_dir",
-        type=str,
-        default=None,
-        help="The embeddings directory where Textual Inversion embeddings are stored.",
-    )
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default=None,
-        help="A mode to filter the dataset.",
-    )
-    parser.add_argument(
-        "--seed",
-        type=int,
-        default=None,
-        help="A seed for reproducible training.")
-    parser.add_argument(
-        "--resolution",
-        type=int,
-        default=768,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
-    )
-    parser.add_argument(
-        "--center_crop",
-        action="store_true",
-        help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument(
-        "--tag_dropout",
-        type=float,
-        default=0,
-        help="Tag dropout probability.",
-    )
-    parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
-        "--num_train_epochs",
-        type=int,
-        default=100
-    )
-    parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=None,
-        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
-    )
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
-    parser.add_argument(
-        "--gradient_checkpointing",
-        action="store_true",
-        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
-    )
-    parser.add_argument(
-        "--learning_rate",
-        type=float,
-        default=1e-4,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
-        "--scale_lr",
-        action="store_true",
-        default=True,
-        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
-    )
-    parser.add_argument(
-        "--lr_scheduler",
-        type=str,
-        default="one_cycle",
-        help=(
-            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup", "one_cycle"]'
-        ),
-    )
-    parser.add_argument(
-        "--lr_warmup_epochs",
-        type=int,
-        default=10,
-        help="Number of steps for the warmup in the lr scheduler."
-    )
-    parser.add_argument(
-        "--lr_cycles",
-        type=int,
-        default=None,
-        help="Number of restart cycles in the lr scheduler."
-    )
-    parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
-    )
-    parser.add_argument(
-        "--adam_beta1",
-        type=float,
-        default=0.9,
-        help="The beta1 parameter for the Adam optimizer."
-    )
-    parser.add_argument(
-        "--adam_beta2",
-        type=float,
-        default=0.999,
-        help="The beta2 parameter for the Adam optimizer."
-    )
-    parser.add_argument(
-        "--adam_weight_decay",
-        type=float,
-        default=1e-2,
-        help="Weight decay to use."
-    )
-    parser.add_argument(
-        "--adam_epsilon",
-        type=float,
-        default=1e-08,
-        help="Epsilon value for the Adam optimizer"
-    )
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="no",
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose"
-            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."
-        ),
-    )
-    parser.add_argument(
-        "--checkpoint_frequency",
-        type=int,
-        default=5,
-        help="How often to save a checkpoint and sample image (in epochs)",
-    )
-    parser.add_argument(
-        "--sample_frequency",
-        type=int,
-        default=1,
-        help="How often to save a checkpoint and sample image (in epochs)",
-    )
-    parser.add_argument(
-        "--sample_image_size",
-        type=int,
-        default=768,
-        help="Size of sample images",
-    )
-    parser.add_argument(
-        "--sample_batches",
-        type=int,
-        default=1,
-        help="Number of sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--sample_batch_size",
-        type=int,
-        default=1,
-        help="Number of samples to generate per batch",
-    )
-    parser.add_argument(
-        "--valid_set_size",
-        type=int,
-        default=None,
-        help="Number of images in the validation dataset."
-    )
-    parser.add_argument(
-        "--train_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
-        "--sample_steps",
-        type=int,
-        default=15,
-        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
-    )
-    parser.add_argument(
-        "--prior_loss_weight",
-        type=float,
-        default=1.0,
-        help="The weight of prior preservation loss."
-    )
-    parser.add_argument(
-        "--noise_timesteps",
-        type=int,
-        default=1000,
-    )
-    parser.add_argument(
-        "--resume_from",
-        type=str,
-        default=None,
-        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
-    )
-    parser.add_argument(
-        "--global_step",
-        type=int,
-        default=0,
-    )
-    parser.add_argument(
-        "--config",
-        type=str,
-        default=None,
-        help="Path to a JSON configuration file containing arguments for invoking this script."
-    )
-
-    args = parser.parse_args()
-    if args.config is not None:
-        with open(args.config, 'rt') as f:
-            args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)["args"]))
-
-    if args.train_data_file is None:
-        raise ValueError("You must specify --train_data_file")
-
-    if args.pretrained_model_name_or_path is None:
-        raise ValueError("You must specify --pretrained_model_name_or_path")
-
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token]
-
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
-
-    if isinstance(args.placeholder_token, str):
-        args.placeholder_token = [args.placeholder_token]
-
-    if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
-
-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("You must specify --placeholder_token")
-
-    if args.output_dir is None:
-        raise ValueError("You must specify --output_dir")
-
-    return args
-
-
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class Checkpointer:
-    def __init__(
-        self,
-        datamodule,
-        accelerator,
-        vae,
-        unet,
-        tokenizer,
-        text_encoder,
-        scheduler,
-        instance_identifier,
-        placeholder_token,
-        placeholder_token_id,
-        output_dir: Path,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed
-    ):
-        self.datamodule = datamodule
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.scheduler = scheduler
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.output_dir = output_dir
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
-
-    @torch.no_grad()
-    def checkpoint(self, step, postfix):
-        print("Saving checkpoint for step %d..." % step)
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
-            # Save a checkpoint
-            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
-            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
-
-            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
-
-        del text_encoder
-        del learned_embeds
-
-    @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        # Save a sample image
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=self.unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        ).to(self.accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
-
-        del text_encoder
-        del pipeline
-        del generator
-        del stable_latents
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-
-def main():
-    args = parse_args()
-
-    instance_identifier = args.instance_identifier
-
-    if len(args.placeholder_token) != 0:
-        instance_identifier = instance_identifier.format(args.placeholder_token[0])
-
-    global_step_offset = args.global_step
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
-    basepath.mkdir(parents=True, exist_ok=True)
-
-    accelerator = Accelerator(
-        log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision
-    )
-
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
-
-    args.seed = args.seed or (torch.random.seed() >> 32)
-    set_seed(args.seed)
-
-    # Load the tokenizer and add the placeholder token as a additional special token
-    if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
-    elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
-
-    # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
-    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    set_use_memory_efficient_attention_xformers(unet, True)
-    set_use_memory_efficient_attention_xformers(vae, True)
-
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
-    if args.embeddings_dir is not None:
-        embeddings_dir = Path(args.embeddings_dir)
-        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-            raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
-        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
-        for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
-
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    if args.resume_from is not None:
-        resumepath = Path(args.resume_from).joinpath("checkpoints")
-
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
-
-    original_token_embeds = token_embeds.clone().to(accelerator.device)
-
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-        token_embeds[token_id] = embeddings
-
-    index_fixed_tokens = torch.arange(len(tokenizer))
-    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
-
-    # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
-    # Freeze all parameters except for the token embeddings in text encoder
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
-
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
-    if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
-        )
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
-    if args.use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
-
-        optimizer_class = bnb.optim.AdamW8bit
-    else:
-        optimizer_class = torch.optim.AdamW
-
-    # Initialize the optimizer
-    optimizer = optimizer_class(
-        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-    )
-
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    def keyword_filter(item: CSVDataItem):
-        return any(keyword in item.prompt for keyword in args.placeholder_token)
-
-    def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompts": prompts,
-            "nprompts": nprompts,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-        return batch
-
-    datamodule = CSVDataModule(
-        data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
-        prompt_processor=prompt_processor,
-        instance_identifier=args.instance_identifier,
-        class_identifier=args.class_identifier,
-        class_subdir="cls",
-        num_class_images=args.num_class_images,
-        size=args.resolution,
-        repeats=args.repeats,
-        mode=args.mode,
-        dropout=args.tag_dropout,
-        center_crop=args.center_crop,
-        template_key=args.train_data_template,
-        valid_set_size=args.valid_set_size,
-        num_workers=args.dataloader_num_workers,
-        filter=keyword_filter,
-        collate_fn=collate_fn
-    )
-
-    datamodule.prepare_data()
-    datamodule.setup()
-
-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.autocast("cuda"), torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-    train_dataloader = datamodule.train_dataloader()
-    val_dataloader = datamodule.val_dataloader()
-
-    # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
-    if args.lr_scheduler == "one_cycle":
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-        )
-
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    # Move vae and unet to device
-    vae.to(accelerator.device, dtype=weight_dtype)
-    unet.to(accelerator.device, dtype=weight_dtype)
-
-    # Keep vae and unet in eval mode as we don't train these
-    vae.eval()
-    unet.eval()
-
-    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-
-    num_val_steps_per_epoch = len(val_dataloader)
-    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-    val_steps = num_val_steps_per_epoch * num_epochs
-
-    # We need to initialize the trackers we use, and also store our configuration.
-    # The trackers initializes automatically on the main process.
-    if accelerator.is_main_process:
-        config = vars(args).copy()
-        config["initializer_token"] = " ".join(config["initializer_token"])
-        config["placeholder_token"] = " ".join(config["placeholder_token"])
-        accelerator.init_trackers("textual_inversion", config=config)
-
-    # Train!
-    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-
-    logger.info("***** Running training *****")
-    logger.info(f"  Num Epochs = {num_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {args.max_train_steps}")
-    # Only show the progress bar once on each machine.
-
-    global_step = 0
-    min_val_loss = np.inf
-
-    checkpointer = Checkpointer(
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        scheduler=checkpoint_scheduler,
-        instance_identifier=args.instance_identifier,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
-        output_dir=basepath,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
-
-    local_progress_bar = tqdm(
-        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    local_progress_bar.set_description("Epoch X / Y")
-
-    global_progress_bar = tqdm(
-        range(args.max_train_steps + val_steps),
-        disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
-    )
-    global_progress_bar.set_description("Total progress")
-
-    try:
-        for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-            local_progress_bar.reset()
-
-            text_encoder.train()
-            train_loss = 0.0
-
-            sample_checkpoint = False
-
-            for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(text_encoder):
-                    # Convert images to latent space
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-                    latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
-
-                    # Predict the noise residual
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    if args.num_class_images != 0:
-                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-                        target, target_prior = torch.chunk(target, 2, dim=0)
-
-                        # Compute instance loss
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
-
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    accelerator.backward(loss)
-
-                    optimizer.step()
-                    if not accelerator.optimizer_step_was_skipped:
-                        lr_scheduler.step()
-                    optimizer.zero_grad(set_to_none=True)
-
-                    # Let's make sure we don't update any embedding weights besides the newly added token
-                    with torch.no_grad():
-                        text_encoder.get_input_embeddings(
-                        ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
-
-                    loss = loss.detach().item()
-                    train_loss += loss
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
-
-                    global_step += 1
-
-                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
-
-                accelerator.log(logs, step=global_step)
-
-                local_progress_bar.set_postfix(**logs)
-
-                if global_step >= args.max_train_steps:
-                    break
-
-            train_loss /= len(train_dataloader)
-
-            accelerator.wait_for_everyone()
-
-            text_encoder.eval()
-            val_loss = 0.0
-
-            with torch.inference_mode():
-                for step, batch in enumerate(val_dataloader):
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                    noise = torch.randn_like(latents)
-                    bsz = latents.shape[0]
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
-
-                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                    # Get the target for loss depending on the prediction type
-                    if noise_scheduler.config.prediction_type == "epsilon":
-                        target = noise
-                    elif noise_scheduler.config.prediction_type == "v_prediction":
-                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                    else:
-                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                    loss = loss.detach().item()
-                    val_loss += loss
-
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                    logs = {"val/loss": loss}
-                    local_progress_bar.set_postfix(**logs)
-
-            val_loss /= len(val_dataloader)
-
-            accelerator.log({"val/loss": val_loss}, step=global_step)
-
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
-            if accelerator.is_main_process:
-                if min_val_loss > val_loss:
-                    accelerator.print(
-                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
-                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
-                    min_val_loss = val_loss
-
-                if (epoch + 1) % args.checkpoint_frequency == 0:
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
-                    save_args(basepath, args, {
-                        "global_step": global_step + global_step_offset
-                    })
-
-                if (epoch + 1) % args.sample_frequency == 0:
-                    checkpointer.save_samples(
-                        global_step + global_step_offset,
-                        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
-
-        # Create the pipeline using using the trained modules and save it.
-        if accelerator.is_main_process:
-            print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_args(basepath, args, {
-                "global_step": global_step + global_step_offset
-            })
-            accelerator.end_training()
-
-    except KeyboardInterrupt:
-        if accelerator.is_main_process:
-            print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            save_args(basepath, args, {
-                "global_step": global_step + global_step_offset
-            })
-            accelerator.end_training()
-        quit()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/train_dreambooth.py b/train_dreambooth.py
new file mode 100644
index 0000000..3eecf9c
--- /dev/null
+++ b/train_dreambooth.py
@@ -0,0 +1,1133 @@
+import argparse
+import itertools
+import math
+import datetime
+import logging
+import json
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import LoggerType, set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+from diffusers.training_utils import EMAModel
+from PIL import Image
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from slugify import slugify
+
+from common import load_text_embeddings
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from pipelines.util import set_use_memory_efficient_attention_xformers
+from data.csv import CSVDataModule
+from training.optimization import get_one_cycle_schedule
+from models.clip.prompt import PromptProcessor
+
+logger = get_logger(__name__)
+
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--train_data_file",
+        type=str,
+        default=None,
+        help="A folder containing the training data."
+    )
+    parser.add_argument(
+        "--train_data_template",
+        type=str,
+        default="template",
+    )
+    parser.add_argument(
+        "--instance_identifier",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        nargs='*',
+        default=[],
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--initializer_token",
+        type=str,
+        nargs='*',
+        default=[],
+        help="A token to use as initializer word."
+    )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        default=True,
+        help="Whether to train the whole text encoder."
+    )
+    parser.add_argument(
+        "--train_text_encoder_epochs",
+        type=int,
+        default=999999,
+        help="Number of epochs for which the text encoder will be trained."
+    )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=400,
+        help="How many class images to generate."
+    )
+    parser.add_argument(
+        "--repeats",
+        type=int,
+        default=1,
+        help="How many times to repeat the training data."
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/dreambooth",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default=None,
+        help="A mode to filter the dataset.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="A seed for reproducible training."
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=768,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--center_crop",
+        action="store_true",
+        help="Whether to center crop images before resizing to resolution"
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
+    parser.add_argument(
+        "--num_train_epochs",
+        type=int,
+        default=100
+    )
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate_unet",
+        type=float,
+        default=2e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--learning_rate_text",
+        type=float,
+        default=2e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=True,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="one_cycle",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_epochs",
+        type=int,
+        default=10,
+        help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=None,
+        help="Number of restart cycles in the lr scheduler (if supported)."
+    )
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=6/7
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        default=True,
+        help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--adam_beta1",
+        type=float,
+        default=0.9,
+        help="The beta1 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_beta2",
+        type=float,
+        default=0.999,
+        help="The beta2 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_weight_decay",
+        type=float,
+        default=1e-2,
+        help="Weight decay to use."
+    )
+    parser.add_argument(
+        "--adam_epsilon",
+        type=float,
+        default=1e-08,
+        help="Epsilon value for the Adam optimizer"
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose"
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+            "and an Nvidia Ampere GPU."
+        ),
+    )
+    parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=1,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
+        "--sample_image_size",
+        type=int,
+        default=768,
+        help="Size of sample images",
+    )
+    parser.add_argument(
+        "--sample_batches",
+        type=int,
+        default=1,
+        help="Number of sample batches to generate per checkpoint",
+    )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=1,
+        help="Number of samples to generate per batch",
+    )
+    parser.add_argument(
+        "--valid_set_size",
+        type=int,
+        default=None,
+        help="Number of images in the validation dataset."
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_steps",
+        type=int,
+        default=15,
+        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
+    )
+    parser.add_argument(
+        "--prior_loss_weight",
+        type=float,
+        default=1.0,
+        help="The weight of prior preservation loss."
+    )
+    parser.add_argument(
+        "--max_grad_norm",
+        default=1.0,
+        type=float,
+        help="Max gradient norm."
+    )
+    parser.add_argument(
+        "--noise_timesteps",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to a JSON configuration file containing arguments for invoking this script."
+    )
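+    # The --config file is expected to mirror the args.json written by save_args(), i.e.
+    # {"args": {"pretrained_model_name_or_path": "...", "train_data_file": "...", ...}}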
+
+    args = parser.parse_args()
+    if args.config is not None:
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
+
+    if args.pretrained_model_name_or_path is None:
+        raise ValueError("You must specify --pretrained_model_name_or_path")
+
+    if args.instance_identifier is None:
+        raise ValueError("You must specify --instance_identifier")
+
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
+
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+
+    if args.output_dir is None:
+        raise ValueError("You must specify --output_dir")
+
+    return args
+
+
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
+def freeze_params(params):
+    for param in params:
+        param.requires_grad = False
+
+
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=(i % cols*w, i//cols*h))
+    return grid
+
+
+class AverageMeter:
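+    """Running-average helper: update(val, n) folds in a batch of size n; .avg is the mean so far."""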
+    def __init__(self, name=None):
+        self.name = name
+        self.reset()
+
+    def reset(self):
+        self.sum = self.count = self.avg = 0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class Checkpointer:
+    def __init__(
+        self,
+        datamodule,
+        accelerator,
+        vae,
+        unet,
+        ema_unet,
+        tokenizer,
+        text_encoder,
+        scheduler,
+        output_dir: Path,
+        instance_identifier,
+        placeholder_token,
+        placeholder_token_id,
+        sample_image_size,
+        sample_batches,
+        sample_batch_size,
+        seed
+    ):
+        self.datamodule = datamodule
+        self.accelerator = accelerator
+        self.vae = vae
+        self.unet = unet
+        self.ema_unet = ema_unet
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+        self.scheduler = scheduler
+        self.output_dir = output_dir
+        self.instance_identifier = instance_identifier
+        self.placeholder_token = placeholder_token
+        self.placeholder_token_id = placeholder_token_id
+        self.sample_image_size = sample_image_size
+        self.seed = seed or torch.random.seed()
+        self.sample_batches = sample_batches
+        self.sample_batch_size = sample_batch_size
+
+    @torch.no_grad()
+    def save_model(self):
+        print("Saving model...")
+
+        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        pipeline = VlpnStableDiffusion(
+            text_encoder=text_encoder,
+            vae=self.vae,
+            unet=unet,
+            tokenizer=self.tokenizer,
+            scheduler=self.scheduler,
+        )
+        pipeline.save_pretrained(self.output_dir.joinpath("model"))
+
+        del unet
+        del text_encoder
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @torch.no_grad()
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+        samples_path = Path(self.output_dir).joinpath("samples")
+
+        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        pipeline = VlpnStableDiffusion(
+            text_encoder=text_encoder,
+            vae=self.vae,
+            unet=unet,
+            tokenizer=self.tokenizer,
+            scheduler=self.scheduler,
+        ).to(self.accelerator.device)
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+        train_data = self.datamodule.train_dataloader()
+        val_data = self.datamodule.val_dataloader()
+
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
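+        # Fixed, seeded latents are reused for the "stable" pool below so those samples stay
+        # comparable across checkpoints; the "val"/"train" pools sample fresh noise each time.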
+
+        with torch.autocast("cuda"), torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+                file_path.parent.mkdir(parents=True, exist_ok=True)
+
+                data_enum = enumerate(data)
+
+                batches = [
+                    batch
+                    for j, batch in data_enum
+                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
+                ]
+                prompts = [
+                    prompt.format(identifier=self.instance_identifier)
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ]
+                nprompts = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ]
+
+                for i in range(self.sample_batches):
+                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+
+                    samples = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        height=self.sample_image_size,
+                        width=self.sample_image_size,
+                        image=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
+                        guidance_scale=guidance_scale,
+                        eta=eta,
+                        num_inference_steps=num_inference_steps,
+                        output_type='pil'
+                    ).images
+
+                    all_samples += samples
+
+                    del samples
+
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path, quality=85)
+
+                del all_samples
+                del image_grid
+
+        del unet
+        del text_encoder
+        del pipeline
+        del generator
+        del stable_latents
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+def main():
+    args = parse_args()
+
+    instance_identifier = args.instance_identifier
+
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
+    basepath.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
+
+    # This check relies on the Accelerator instance and therefore has to run after it is created.
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
+
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
+
+    save_args(basepath, args)
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+    if args.tokenizer_name:
+        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+    elif args.pretrained_model_name_or_path:
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+
+    # Load models and create wrapper for stable diffusion
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
+    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder='scheduler')
+
+    vae.enable_slicing()
+    set_use_memory_efficient_attention_xformers(unet, True)
+    set_use_memory_efficient_attention_xformers(vae, True)
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    ema_unet = None
+    if args.use_ema:
+        ema_unet = EMAModel(
+            unet,
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+            device=accelerator.device
+        )
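+        # When EMA is enabled, Checkpointer.save_model() and save_samples() use
+        # ema_unet.averaged_model instead of the live unet.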
+
+    # Freeze vae (text encoder training is decided below via --train_text_encoder)
+    vae.requires_grad_(False)
+
+    if args.embeddings_dir is not None:
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+    if len(args.placeholder_token) != 0:
+        # Convert the initializer_token, placeholder_token to ids
+        initializer_token_ids = torch.stack([
+            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+            for token in args.initializer_token
+        ])
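+        # Only the first sub-token of each initializer word is kept ([:1]), so multi-token
+        # initializers fall back to their leading token.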
+
+        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+        print(f"Added {num_added_tokens} new tokens.")
+
+        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+        # Resize the token embeddings as we are adding new special tokens to the tokenizer
+        text_encoder.resize_token_embeddings(len(tokenizer))
+
+        token_embeds = text_encoder.get_input_embeddings().weight.data
+        original_token_embeds = token_embeds.clone().to(accelerator.device)
+        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
+    else:
+        placeholder_token_id = []
+
+    if args.train_text_encoder:
+        print("Training entire text encoder.")
+    else:
+        print("Training added text embeddings.")
+
+        freeze_params(itertools.chain(
+            text_encoder.text_model.encoder.parameters(),
+            text_encoder.text_model.final_layer_norm.parameters(),
+            text_encoder.text_model.embeddings.position_embedding.parameters(),
+        ))
+
+        index_fixed_tokens = torch.arange(len(tokenizer))
+        index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
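+        # Indices of every embedding except the newly added placeholder tokens; used so that,
+        # when only the added embeddings are trained, the pre-existing rows can be kept unchanged.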
+
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
+    if args.scale_lr:
+        args.learning_rate_unet = (
+            args.learning_rate_unet * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+        args.learning_rate_text = (
+            args.learning_rate_text * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
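+        # Example: a base lr of 2e-6 with 2 accumulation steps, batch size 4 and one process
+        # becomes 2e-6 * 2 * 4 * 1 = 1.6e-5.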
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    if args.train_text_encoder:
+        text_encoder_params_to_optimize = text_encoder.parameters()
+    else:
+        text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters()
+
+    # Initialize the optimizer
+    optimizer = optimizer_class(
+        [
+            {
+                'params': unet.parameters(),
+                'lr': args.learning_rate_unet,
+            },
+            {
+                'params': text_encoder_params_to_optimize,
+                'lr': args.learning_rate_text,
+            }
+        ],
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+    )
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    def collate_fn(examples):
+        prompts = [example["prompts"] for example in examples]
+        nprompts = [example["nprompts"] for example in examples]
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
+
+        # concat class and instance examples for prior preservation
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
+
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+
+        inputs = prompt_processor.unify_input_ids(input_ids)
+
+        batch = {
+            "prompts": prompts,
+            "nprompts": nprompts,
+            "input_ids": inputs.input_ids,
+            "pixel_values": pixel_values,
+            "attention_mask": inputs.attention_mask,
+        }
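+        # With prior preservation enabled, input_ids/pixel_values contain the instance examples
+        # first and the class examples second, in equal halves.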
+        return batch
+
+    datamodule = CSVDataModule(
+        data_file=args.train_data_file,
+        batch_size=args.train_batch_size,
+        prompt_processor=prompt_processor,
+        instance_identifier=instance_identifier,
+        class_identifier=args.class_identifier,
+        class_subdir="cls",
+        num_class_images=args.num_class_images,
+        size=args.resolution,
+        repeats=args.repeats,
+        mode=args.mode,
+        dropout=args.tag_dropout,
+        center_crop=args.center_crop,
+        template_key=args.train_data_template,
+        valid_set_size=args.valid_set_size,
+        num_workers=args.dataloader_num_workers,
+        collate_fn=collate_fn
+    )
+
+    datamodule.prepare_data()
+    datamodule.setup()
+
+    if args.num_class_images != 0:
+        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
+
+        if len(missing_data) != 0:
+            batched_data = [
+                missing_data[i:i+args.sample_batch_size]
+                for i in range(0, len(missing_data), args.sample_batch_size)
+            ]
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=checkpoint_scheduler,
+            ).to(accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+            with torch.autocast("cuda"), torch.inference_mode():
+                for batch in batched_data:
+                    image_name = [item.class_image_path for item in batch]
+                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                    nprompt = [item.nprompt for item in batch]
+
+                    images = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        num_inference_steps=args.sample_steps
+                    ).images
+
+                    for i, image in enumerate(images):
+                        image.save(image_name[i])
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+    train_dataloader = datamodule.train_dataloader()
+    val_dataloader = datamodule.val_dataloader()
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
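+    # The warmup length is given in epochs; convert it to scheduler steps.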
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
+    if args.lr_scheduler == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+    elif args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
+                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    # Move vae to device
+    vae.to(accelerator.device, dtype=weight_dtype)
+
+    # Keep vae in eval mode as we don't train it
+    vae.eval()
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+    num_val_steps_per_epoch = len(val_dataloader)
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    val_steps = num_val_steps_per_epoch * num_epochs
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if accelerator.is_main_process:
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("dreambooth", config=config)
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num Epochs = {num_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    # Only show the progress bar once on each machine.
+
+    global_step = 0
+
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
+    max_acc_val = 0.0
+
+    checkpointer = Checkpointer(
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        ema_unet=ema_unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        scheduler=checkpoint_scheduler,
+        output_dir=basepath,
+        instance_identifier=instance_identifier,
+        placeholder_token=args.placeholder_token,
+        placeholder_token_id=placeholder_token_id,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
+    if accelerator.is_main_process:
+        checkpointer.save_samples(0, args.sample_steps)
+
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    local_progress_bar.set_description("Epoch X / Y")
+
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    global_progress_bar.set_description("Total progress")
+
+    try:
+        for epoch in range(num_epochs):
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+            local_progress_bar.reset()
+
+            unet.train()
+
+            if epoch < args.train_text_encoder_epochs:
+                text_encoder.train()
+            elif epoch == args.train_text_encoder_epochs:
+                freeze_params(text_encoder.parameters())
+
+            sample_checkpoint = False
+
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(unet):
+                    # Convert images to latent space
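+                    # (the latents are scaled by 0.18215, the factor used by the Stable Diffusion VAE)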
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+                    # Predict the noise residual
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
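+                    # (epsilon models predict the added noise; v_prediction models predict the velocity)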
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                    if args.num_class_images != 0:
+                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                        target, target_prior = torch.chunk(target, 2, dim=0)
+
+                        # Compute instance loss
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+                        # Compute prior loss
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+                        # Add the prior loss to the instance loss.
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                    acc = (model_pred == target).float().mean()
+
+                    accelerator.backward(loss)
+
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            itertools.chain(unet.parameters(), text_encoder.parameters())
+                            if args.train_text_encoder and epoch < args.train_text_encoder_epochs
+                            else unet.parameters()
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    if args.use_ema:
+                        ema_unet.step(unet)
+                    optimizer.zero_grad(set_to_none=True)
+
+                    if not args.train_text_encoder:
+                        # Let's make sure we don't update any embedding weights besides the newly added token
+                        with torch.no_grad():
+                            text_encoder.get_input_embeddings(
+                            ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    global_step += 1
+
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr/unet": lr_scheduler.get_last_lr()[0],
+                    "lr/text": lr_scheduler.get_last_lr()[1]
+                }
+                if args.use_ema:
+                    logs["ema_decay"] = 1 - ema_unet.decay
+
+                accelerator.log(logs, step=global_step)
+
+                local_progress_bar.set_postfix(**logs)
+
+                if global_step >= args.max_train_steps:
+                    break
+
+            accelerator.wait_for_everyone()
+
+            unet.eval()
+            text_encoder.eval()
+
+            with torch.inference_mode():
+                for step, batch in enumerate(val_dataloader):
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                    acc = (model_pred == target).float().mean()
+
+                    avg_loss_val.update(loss.detach_(), bsz)
+                    avg_acc_val.update(acc.detach_(), bsz)
+
+                    if accelerator.sync_gradients:
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                    logs = {
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
+                    local_progress_bar.set_postfix(**logs)
+
+            accelerator.log({
+                "val/loss": avg_loss_val.avg.item(),
+                "val/acc": avg_acc_val.avg.item(),
+            }, step=global_step)
+
+            local_progress_bar.clear()
+            global_progress_bar.clear()
+
+            if avg_acc_val.avg.item() > max_acc_val:
+                accelerator.print(
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                max_acc_val = avg_acc_val.avg.item()
+
+            if accelerator.is_main_process:
+                if (epoch + 1) % args.sample_frequency == 0:
+                    checkpointer.save_samples(global_step, args.sample_steps)
+
+        # Create the pipeline using the trained modules and save it.
+        if accelerator.is_main_process:
+            print("Finished! Saving final checkpoint and resume state.")
+            checkpointer.save_model()
+
+            accelerator.end_training()
+
+    except KeyboardInterrupt:
+        if accelerator.is_main_process:
+            print("Interrupted, saving checkpoint and resume state...")
+            checkpointer.save_model()
+            accelerator.end_training()
+        quit()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/train_ti.py b/train_ti.py
new file mode 100644
index 0000000..dbfe58c
--- /dev/null
+++ b/train_ti.py
@@ -0,0 +1,1032 @@
+import argparse
+import itertools
+import math
+import os
+import datetime
+import logging
+import json
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import LoggerType, set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
+from PIL import Image
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from slugify import slugify
+
+from common import load_text_embeddings, load_text_embedding
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from pipelines.util import set_use_memory_efficient_attention_xformers
+from data.csv import CSVDataModule, CSVDataItem
+from training.optimization import get_one_cycle_schedule
+from models.clip.prompt import PromptProcessor
+
+logger = get_logger(__name__)
+
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--train_data_file",
+        type=str,
+        default=None,
+        help="A CSV file containing the training data."
+    )
+    parser.add_argument(
+        "--train_data_template",
+        type=str,
+        default="template",
+    )
+    parser.add_argument(
+        "--instance_identifier",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        nargs='*',
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--initializer_token",
+        type=str,
+        nargs='*',
+        help="A token to use as initializer word."
+    )
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=400,
+        help="How many class images to generate."
+    )
+    parser.add_argument(
+        "--repeats",
+        type=int,
+        default=1,
+        help="How many times to repeat the training data."
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/text-inversion",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default=None,
+        help="A mode to filter the dataset.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="A seed for reproducible training.")
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=768,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--center_crop",
+        action="store_true",
+        help="Whether to center crop images before resizing to resolution"
+    )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
+    parser.add_argument(
+        "--num_train_epochs",
+        type=int,
+        default=100
+    )
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=True,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="one_cycle",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_epochs",
+        type=int,
+        default=10,
+        help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=None,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--adam_beta1",
+        type=float,
+        default=0.9,
+        help="The beta1 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_beta2",
+        type=float,
+        default=0.999,
+        help="The beta2 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_weight_decay",
+        type=float,
+        default=1e-2,
+        help="Weight decay to use."
+    )
+    parser.add_argument(
+        "--adam_epsilon",
+        type=float,
+        default=1e-08,
+        help="Epsilon value for the Adam optimizer"
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose"
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+            "and an Nvidia Ampere GPU."
+        ),
+    )
+    parser.add_argument(
+        "--checkpoint_frequency",
+        type=int,
+        default=5,
+        help="How often to save a checkpoint and sample image (in epochs)",
+    )
+    parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=1,
+        help="How often to save a checkpoint and sample image (in epochs)",
+    )
+    parser.add_argument(
+        "--sample_image_size",
+        type=int,
+        default=768,
+        help="Size of sample images",
+    )
+    parser.add_argument(
+        "--sample_batches",
+        type=int,
+        default=1,
+        help="Number of sample batches to generate per checkpoint",
+    )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=1,
+        help="Number of samples to generate per batch",
+    )
+    parser.add_argument(
+        "--valid_set_size",
+        type=int,
+        default=None,
+        help="Number of images in the validation dataset."
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_steps",
+        type=int,
+        default=15,
+        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
+    )
+    parser.add_argument(
+        "--prior_loss_weight",
+        type=float,
+        default=1.0,
+        help="The weight of prior preservation loss."
+    )
+    parser.add_argument(
+        "--noise_timesteps",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
+        "--resume_from",
+        type=str,
+        default=None,
+        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
+    )
+    parser.add_argument(
+        "--global_step",
+        type=int,
+        default=0,
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to a JSON configuration file containing arguments for invoking this script."
+    )
+
+    args = parser.parse_args()
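+    # A JSON config file can provide defaults under an "args" key; flags given on the
+    # command line still take precedence. Example (hypothetical values):
+    #   { "args": { "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
+    #               "train_data_file": "data/train.csv" } }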
+    if args.config is not None:
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
+
+    if args.pretrained_model_name_or_path is None:
+        raise ValueError("You must specify --pretrained_model_name_or_path")
+
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
+
+    if not args.initializer_token:
+        raise ValueError("You must specify --initializer_token")
+
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if not args.placeholder_token:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if args.output_dir is None:
+        raise ValueError("You must specify --output_dir")
+
+    return args
+
+
+def freeze_params(params):
+    for param in params:
+        param.requires_grad = False
+
+
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=(i % cols*w, i//cols*h))
+    return grid
+
+
+class Checkpointer:
+    def __init__(
+        self,
+        datamodule,
+        accelerator,
+        vae,
+        unet,
+        tokenizer,
+        text_encoder,
+        scheduler,
+        instance_identifier,
+        placeholder_token,
+        placeholder_token_id,
+        output_dir: Path,
+        sample_image_size,
+        sample_batches,
+        sample_batch_size,
+        seed
+    ):
+        self.datamodule = datamodule
+        self.accelerator = accelerator
+        self.vae = vae
+        self.unet = unet
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+        self.scheduler = scheduler
+        self.instance_identifier = instance_identifier
+        self.placeholder_token = placeholder_token
+        self.placeholder_token_id = placeholder_token_id
+        self.output_dir = output_dir
+        self.sample_image_size = sample_image_size
+        self.seed = seed or torch.random.seed()
+        self.sample_batches = sample_batches
+        self.sample_batch_size = sample_batch_size
+
+    @torch.no_grad()
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
+
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+
+        del text_encoder
+        del learned_embeds
+
+    @torch.no_grad()
+    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
+        samples_path = Path(self.output_dir).joinpath("samples")
+
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        # Save a sample image
+        pipeline = VlpnStableDiffusion(
+            text_encoder=text_encoder,
+            vae=self.vae,
+            unet=self.unet,
+            tokenizer=self.tokenizer,
+            scheduler=self.scheduler,
+        ).to(self.accelerator.device)
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+        train_data = self.datamodule.train_dataloader()
+        val_data = self.datamodule.val_dataloader()
+
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
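+        # Fixed latents for the 'stable' sample pool, so those grids stay comparable across checkpoints.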
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
+
+        with torch.autocast("cuda"), torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+                file_path.parent.mkdir(parents=True, exist_ok=True)
+
+                data_enum = enumerate(data)
+
+                batches = [
+                    batch
+                    for j, batch in data_enum
+                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
+                ]
+                prompts = [
+                    prompt.format(identifier=self.instance_identifier)
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ]
+                nprompts = [
+                    nprompt
+                    for batch in batches
+                    for nprompt in batch["nprompts"]
+                ]
+
+                for i in range(self.sample_batches):
+                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+
+                    samples = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        height=self.sample_image_size,
+                        width=self.sample_image_size,
+                        image=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
+                        guidance_scale=guidance_scale,
+                        eta=eta,
+                        num_inference_steps=num_inference_steps,
+                        output_type='pil'
+                    ).images
+
+                    all_samples += samples
+
+                    del samples
+
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path, quality=85)
+
+                del all_samples
+                del image_grid
+
+        del text_encoder
+        del pipeline
+        del generator
+        del stable_latents
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+def main():
+    args = parse_args()
+
+    instance_identifier = args.instance_identifier
+
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
+
+    global_step_offset = args.global_step
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
+    basepath.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
+
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+    if args.tokenizer_name:
+        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+    elif args.pretrained_model_name_or_path:
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+
+    # Load models and create wrapper for stable diffusion
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
+    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder='scheduler')
+
+    vae.enable_slicing()
+    set_use_memory_efficient_attention_xformers(unet, True)
+    set_use_memory_efficient_attention_xformers(vae, True)
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    if args.embeddings_dir is not None:
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
+
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
+
+    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+    print(f"Added {num_added_tokens} new tokens.")
+
+    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+    # Resize the token embeddings as we are adding new special tokens to the tokenizer
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    # Initialise the newly added placeholder tokens with the embeddings of the initializer tokens
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
+    if args.resume_from is not None:
+        resumepath = Path(args.resume_from).joinpath("checkpoints")
+
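+        # Embedding checkpoints are written as "<slugified token>_<step>_<postfix>.bin" by
+        # Checkpointer.checkpoint; resuming expects the "_end" files from the step given via --global_step.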
+        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{slugify(token)}_{args.global_step}_end.bin"))
+
+    original_token_embeds = token_embeds.clone().to(accelerator.device)
+
+    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+        token_embeds[token_id] = embeddings
+
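+    # Indices of all tokens except the new placeholders; their embeddings are restored after every
+    # optimizer step so that only the placeholder embeddings are actually trained.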
+    index_fixed_tokens = torch.arange(len(tokenizer))
+    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+
+    # Freeze vae and unet
+    freeze_params(vae.parameters())
+    freeze_params(unet.parameters())
+    # Freeze all parameters except for the token embeddings in text encoder
+    freeze_params(itertools.chain(
+        text_encoder.text_model.encoder.parameters(),
+        text_encoder.text_model.final_layer_norm.parameters(),
+        text_encoder.text_model.embeddings.position_embedding.parameters(),
+    ))
+
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
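+    # Scale the base learning rate by the effective batch size (batch size * gradient accumulation * processes).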
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    # Initialize the optimizer
+    optimizer = optimizer_class(
+        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+    )
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    def keyword_filter(item: CSVDataItem):
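+        # Keep only training items whose keyword list contains one of the placeholder tokens.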
+        return any(keyword in item.prompt for keyword in args.placeholder_token)
+
+    def collate_fn(examples):
+        prompts = [example["prompts"] for example in examples]
+        nprompts = [example["nprompts"] for example in examples]
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
+
+        # concat class and instance examples for prior preservation
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
+
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+
+        inputs = prompt_processor.unify_input_ids(input_ids)
+
+        batch = {
+            "prompts": prompts,
+            "nprompts": nprompts,
+            "input_ids": inputs.input_ids,
+            "pixel_values": pixel_values,
+            "attention_mask": inputs.attention_mask,
+        }
+        return batch
+
+    datamodule = CSVDataModule(
+        data_file=args.train_data_file,
+        batch_size=args.train_batch_size,
+        prompt_processor=prompt_processor,
+        instance_identifier=args.instance_identifier,
+        class_identifier=args.class_identifier,
+        class_subdir="cls",
+        num_class_images=args.num_class_images,
+        size=args.resolution,
+        repeats=args.repeats,
+        mode=args.mode,
+        dropout=args.tag_dropout,
+        center_crop=args.center_crop,
+        template_key=args.train_data_template,
+        valid_set_size=args.valid_set_size,
+        num_workers=args.dataloader_num_workers,
+        filter=keyword_filter,
+        collate_fn=collate_fn
+    )
+
+    datamodule.prepare_data()
+    datamodule.setup()
+
+    if args.num_class_images != 0:
+        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
+
+        if len(missing_data) != 0:
+            batched_data = [
+                missing_data[i:i+args.sample_batch_size]
+                for i in range(0, len(missing_data), args.sample_batch_size)
+            ]
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=checkpoint_scheduler,
+            ).to(accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+            with torch.autocast("cuda"), torch.inference_mode():
+                for batch in batched_data:
+                    image_name = [item.class_image_path for item in batch]
+                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                    nprompt = [item.nprompt for item in batch]
+
+                    images = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        num_inference_steps=args.sample_steps
+                    ).images
+
+                    for i, image in enumerate(images):
+                        image.save(image_name[i])
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+    train_dataloader = datamodule.train_dataloader()
+    val_dataloader = datamodule.val_dataloader()
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
+    if args.lr_scheduler == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+    elif args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
+                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=warmup_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    # Move vae and unet to device
+    vae.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device, dtype=weight_dtype)
+
+    # Keep vae and unet in eval mode as we don't train these
+    vae.eval()
+    unet.eval()
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+
+    num_val_steps_per_epoch = len(val_dataloader)
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    val_steps = num_val_steps_per_epoch * num_epochs
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if accelerator.is_main_process:
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("textual_inversion", config=config)
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num Epochs = {num_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    # Only show the progress bar once on each machine.
+
+    global_step = 0
+    min_val_loss = np.inf
+
+    checkpointer = Checkpointer(
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        scheduler=checkpoint_scheduler,
+        instance_identifier=args.instance_identifier,
+        placeholder_token=args.placeholder_token,
+        placeholder_token_id=placeholder_token_id,
+        output_dir=basepath,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
+    if accelerator.is_main_process:
+        checkpointer.save_samples(
+            0,
+            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    local_progress_bar.set_description("Epoch X / Y")
+
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    global_progress_bar.set_description("Total progress")
+
+    try:
+        for epoch in range(num_epochs):
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+            local_progress_bar.reset()
+
+            text_encoder.train()
+            train_loss = 0.0
+
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(text_encoder):
+                    # Convert images to latent space
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+                    latents = latents * 0.18215
+
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
+
+                    # Predict the noise residual
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                    if args.num_class_images != 0:
+                        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                        target, target_prior = torch.chunk(target, 2, dim=0)
+
+                        # Compute instance loss
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+                        # Compute prior loss
+                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+                        # Add the prior loss to the instance loss.
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                    accelerator.backward(loss)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+
+                    # Let's make sure we don't update any embedding weights besides the newly added token
+                    with torch.no_grad():
+                        text_encoder.get_input_embeddings(
+                        ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+
+                    loss = loss.detach().item()
+                    train_loss += loss
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    global_step += 1
+
+                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+
+                accelerator.log(logs, step=global_step)
+
+                local_progress_bar.set_postfix(**logs)
+
+                if global_step >= args.max_train_steps:
+                    break
+
+            train_loss /= len(train_dataloader)
+
+            accelerator.wait_for_everyone()
+
+            text_encoder.eval()
+            val_loss = 0.0
+
+            with torch.inference_mode():
+                for step, batch in enumerate(val_dataloader):
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
+
+                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    # Get the target for loss depending on the prediction type
+                    if noise_scheduler.config.prediction_type == "epsilon":
+                        target = noise
+                    elif noise_scheduler.config.prediction_type == "v_prediction":
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                    loss = loss.detach().item()
+                    val_loss += loss
+
+                    if accelerator.sync_gradients:
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                    logs = {"val/loss": loss}
+                    local_progress_bar.set_postfix(**logs)
+
+            val_loss /= len(val_dataloader)
+
+            accelerator.log({"val/loss": val_loss}, step=global_step)
+
+            local_progress_bar.clear()
+            global_progress_bar.clear()
+
+            if accelerator.is_main_process:
+                if min_val_loss > val_loss:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    min_val_loss = val_loss
+
+                if (epoch + 1) % args.checkpoint_frequency == 0:
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+                    save_args(basepath, args, {
+                        "global_step": global_step + global_step_offset
+                    })
+
+                if (epoch + 1) % args.sample_frequency == 0:
+                    checkpointer.save_samples(
+                        global_step + global_step_offset,
+                        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+
+        # Create the pipeline using the trained modules and save it.
+        if accelerator.is_main_process:
+            print("Finished! Saving final checkpoint and resume state.")
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            save_args(basepath, args, {
+                "global_step": global_step + global_step_offset
+            })
+            accelerator.end_training()
+
+    except KeyboardInterrupt:
+        if accelerator.is_main_process:
+            print("Interrupted, saving checkpoint and resume state...")
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            save_args(basepath, args, {
+                "global_step": global_step + global_step_offset
+            })
+            accelerator.end_training()
+        quit()
+
+
+if __name__ == "__main__":
+    main()
-- 
cgit v1.2.3-70-g09d2