From 73fe0a75cd08244f91d1baea7b63b42f9e4be08c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Sep 2022 12:39:43 +0200
Subject: Added Dreambooth training script

---
 .gitignore                |   1 +
 data/dreambooth/csv.py    | 177 ++++++++++
 data/dreambooth/prompt.py |  16 +
 dreambooth.py             | 825 ++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 1019 insertions(+)
 create mode 100644 data/dreambooth/csv.py
 create mode 100644 data/dreambooth/prompt.py
 create mode 100644 dreambooth.py

diff --git a/.gitignore b/.gitignore
index a8893c3..91a5e07 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,5 +160,6 @@ cython_debug/
 #.idea/
 
 text-inversion-model/
+dreambooth-model/
 conf*.json
 v1-inference.yaml*

diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
new file mode 100644
index 0000000..04df4c6
--- /dev/null
+++ b/data/dreambooth/csv.py
@@ -0,0 +1,177 @@
+import os
+import pandas as pd
+from pathlib import Path
+import PIL
+import pytorch_lightning as pl
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader, random_split
+from torchvision import transforms
+
+
+class CSVDataModule(pl.LightningDataModule):
+    def __init__(self,
+                 batch_size,
+                 data_root,
+                 tokenizer,
+                 instance_prompt,
+                 class_data_root=None,
+                 class_prompt=None,
+                 size=512,
+                 repeats=100,
+                 interpolation="bicubic",
+                 identifier="*",
+                 center_crop=False,
+                 collate_fn=None):
+        super().__init__()
+
+        self.data_root = data_root
+        self.tokenizer = tokenizer
+        self.instance_prompt = instance_prompt
+        self.class_data_root = class_data_root
+        self.class_prompt = class_prompt
+        self.size = size
+        self.repeats = repeats
+        self.identifier = identifier
+        self.center_crop = center_crop
+        self.interpolation = interpolation
+        self.collate_fn = collate_fn
+        self.batch_size = batch_size
+
+    def prepare_data(self):
+        metadata = pd.read_csv(f'{self.data_root}/list.csv')
+        image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
+        captions = [caption for caption in metadata['caption'].values]
+        skips = [skip for skip in metadata['skip'].values]
+        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+
+    def setup(self, stage=None):
+        train_set_size = int(len(self.data_full) * 0.8)
+        valid_set_size = len(self.data_full) - train_set_size
+        self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
+
+        train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
+                                   class_data_root=self.class_data_root,
+                                   class_prompt=self.class_prompt, size=self.size, repeats=self.repeats,
+                                   interpolation=self.interpolation, identifier=self.identifier,
+                                   center_crop=self.center_crop)
+        val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
+                                 class_data_root=self.class_data_root,
+                                 class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation,
+                                 identifier=self.identifier, center_crop=self.center_crop)
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
+                                            shuffle=True, collate_fn=self.collate_fn)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
+
+    def train_dataloader(self):
+        return self.train_dataloader_
+
+    def val_dataloader(self):
+        return self.val_dataloader_
+
+
+class CSVDataset(Dataset):
+    def __init__(self,
+                 data,
+                 tokenizer,
+                 instance_prompt,
+                 class_data_root=None,
+                 class_prompt=None,
+                 size=512,
+                 repeats=1,
+                 interpolation="bicubic",
+                 identifier="*",
+                 center_crop=False,
+                 ):
+
+        self.data = data
+        self.tokenizer = tokenizer
+        self.instance_prompt = instance_prompt
+
+        self.num_instance_images = len(self.data)
+        self._length = self.num_instance_images * repeats
+
+        self.identifier = identifier
+
+        if class_data_root is not None:
+            self.class_data_root = Path(class_data_root)
+            self.class_data_root.mkdir(parents=True, exist_ok=True)
+
+            self.class_images = list(Path(class_data_root).iterdir())
+            self.num_class_images = len(self.class_images)
+            self._length = max(self.num_class_images, self.num_instance_images)
+
+            self.class_prompt = class_prompt
+        else:
+            self.class_data_root = None
+
+        self.interpolation = {"linear": PIL.Image.LINEAR,
+                              "bilinear": PIL.Image.BILINEAR,
+                              "bicubic": PIL.Image.BICUBIC,
+                              "lanczos": PIL.Image.LANCZOS,
+                              }[interpolation]
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+        self.cache = {}
+
+    def __len__(self):
+        return self._length
+
+    def get_example(self, i):
+        image_path, text = self.data[i % self.num_instance_images]
+
+        if image_path in self.cache:
+            return self.cache[image_path]
+
+        example = {}
+
+        instance_image = Image.open(image_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
+
+        text = text.format(self.identifier)
+
+        example["prompts"] = text
+        example["instance_images"] = instance_image
+        example["instance_prompt_ids"] = self.tokenizer(
+            self.instance_prompt,
+            padding="do_not_pad",
+            truncation=True,
+            max_length=self.tokenizer.model_max_length,
+        ).input_ids
+
+        if self.class_data_root:
+            class_image = Image.open(self.class_images[i % self.num_class_images])
+            if not class_image.mode == "RGB":
+                class_image = class_image.convert("RGB")
+
+            example["class_images"] = class_image
+            example["class_prompt_ids"] = self.tokenizer(
+                self.class_prompt,
+                padding="do_not_pad",
+                truncation=True,
+                max_length=self.tokenizer.model_max_length,
+            ).input_ids
+
+        self.cache[image_path] = example
+        return example
+
+    def __getitem__(self, i):
+        example = {}
+        unprocessed_example = self.get_example(i)
+
+        example["prompts"] = unprocessed_example["prompts"]
+        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
+        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
+
+        if self.class_data_root:
+            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
+            example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
+
+        return example
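For reference, prepare_data above expects a list.csv in the data root with image, caption, and skip columns; rows whose skip column is "x" are excluded, and each caption may contain a {} placeholder that get_example fills with the identifier. A minimal sketch (file names hypothetical):

    image,caption,skip
    inst/01.jpg,a photo of {},
    inst/02.jpg,a cropped photo of {},
    inst/03.jpg,a blurry photo of {},x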
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
new file mode 100644
index 0000000..34f510d
--- /dev/null
+++ b/data/dreambooth/prompt.py
@@ -0,0 +1,16 @@
+from torch.utils.data import Dataset
+
+
+class PromptDataset(Dataset):
+    def __init__(self, prompt, num_samples):
+        self.prompt = prompt
+        self.num_samples = num_samples
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, index):
+        example = {}
+        example["prompt"] = self.prompt
+        example["index"] = index
+        return example
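PromptDataset simply repeats a single prompt num_samples times so that a plain DataLoader can batch it; dreambooth.py below uses it to top up missing class images. A usage sketch (prompt and sizes hypothetical):

    from torch.utils.data import DataLoader

    ds = PromptDataset("a photo of a dog", num_samples=4)
    for batch in DataLoader(ds, batch_size=2):
        # batch["prompt"] is a list of 2 strings, batch["index"] a tensor like tensor([0, 1])
        print(batch["prompt"], batch["index"])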
diff --git a/dreambooth.py b/dreambooth.py
new file mode 100644
index 0000000..b6b3594
--- /dev/null
+++ b/dreambooth.py
@@ -0,0 +1,825 @@
+import argparse
+import itertools
+import math
+import os
+import datetime
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import LoggerType, set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from pipelines.stable_diffusion.no_check import NoCheck
+from PIL import Image
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from slugify import slugify
+import json
+
+from data.dreambooth.csv import CSVDataModule
+from data.dreambooth.prompt import PromptDataset
+
+logger = get_logger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help="A folder containing the training data."
+    )
+    parser.add_argument(
+        "--identifier",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--repeats",
+        type=int,
+        default=100,
+        help="How many times to repeat the training data.")
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="dreambooth-model",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="A seed for reproducible training.")
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=512,
+        help=(
+            "The resolution for input images; all images in the train/validation dataset will be resized to this"
+            " resolution."
+        ),
+    )
+    parser.add_argument(
+        "--center_crop",
+        action="store_true",
+        help="Whether to center crop images before resizing to resolution"
+    )
+    parser.add_argument(
+        "--num_train_epochs",
+        type=int,
+        default=100)
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=5000,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=True,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps",
+        type=int,
+        default=500,
+        help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--adam_beta1",
+        type=float,
+        default=0.9,
+        help="The beta1 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_beta2",
+        type=float,
+        default=0.999,
+        help="The beta2 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_weight_decay",
+        type=float,
+        default=1e-2,
+        help="Weight decay to use."
+    )
+    parser.add_argument(
+        "--adam_epsilon",
+        type=float,
+        default=1e-08,
+        help="Epsilon value for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
+            " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
+        ),
+    )
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=-1,
+        help="For distributed training: local_rank"
+    )
+    parser.add_argument(
+        "--checkpoint_frequency",
+        type=int,
+        default=500,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
+        "--sample_image_size",
+        type=int,
+        default=512,
+        help="Size of sample images",
+    )
+    parser.add_argument(
+        "--stable_sample_batches",
+        type=int,
+        default=1,
+        help="Number of fixed seed sample batches to generate per checkpoint",
+    )
+    parser.add_argument(
+        "--random_sample_batches",
+        type=int,
+        default=1,
+        help="Number of random seed sample batches to generate per checkpoint",
+    )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=1,
+        help="Number of samples to generate per batch",
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_steps",
+        type=int,
+        default=50,
+        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
+    )
+    parser.add_argument(
+        "--instance_prompt",
+        type=str,
+        default=None,
+        help="The prompt with identifier specifying the instance",
+    )
+    parser.add_argument(
+        "--class_data_dir",
+        type=str,
+        default=None,
+        required=False,
+        help="A folder containing the training data of class images.",
+    )
+    parser.add_argument(
+        "--class_prompt",
+        type=str,
+        default=None,
+        help="The prompt to specify images in the same class as provided instance images.",
+    )
+    parser.add_argument(
+        "--with_prior_preservation",
+        default=False,
+        action="store_true",
+        help="Flag to add prior preservation loss.",
+    )
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=100,
+        help=(
+            "Minimal number of class images for the prior preservation loss. If there are not enough images already"
+            " present, additional images will be sampled with class_prompt."
+        ),
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to a JSON configuration file containing arguments for invoking this script."
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.train_data_dir is None:
+        raise ValueError("You must specify --train_data_dir")
+
+    if args.pretrained_model_name_or_path is None:
+        raise ValueError("You must specify --pretrained_model_name_or_path")
+
+    if args.instance_prompt is None:
+        raise ValueError("You must specify --instance_prompt")
+
+    if args.identifier is None:
+        raise ValueError("You must specify --identifier")
+
+    if args.output_dir is None:
+        raise ValueError("You must specify --output_dir")
+
+    if args.with_prior_preservation:
+        if args.class_data_dir is None:
+            raise ValueError("You must specify --class_data_dir")
+        if args.class_prompt is None:
+            raise ValueError("You must specify --class_prompt")
+
+    return args
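Because parse_args above re-parses with a namespace built from json.load(f)["args"], a file passed via --config is a JSON object with a single "args" key whose entries mirror the CLI flags. A minimal sketch (paths and values hypothetical):

    {
      "args": {
        "pretrained_model_name_or_path": "models/stable-diffusion-v1-4",
        "train_data_dir": "data/instance",
        "identifier": "sks",
        "instance_prompt": "a photo of sks person",
        "train_batch_size": 1,
        "max_train_steps": 5000
      }
    }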
+
+
+def freeze_params(params):
+    for param in params:
+        param.requires_grad = False
+
+
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    for i, image in enumerate(images):
+        # Paste row-major: (i % cols) selects the column, (i // cols) the row
+        grid.paste(image, box=(i % cols*w, i//cols*h))
+    return grid
+
+
+class Checkpointer:
+    def __init__(
+        self,
+        datamodule,
+        accelerator,
+        vae,
+        unet,
+        tokenizer,
+        text_encoder,
+        output_dir,
+        sample_image_size,
+        random_sample_batches,
+        sample_batch_size,
+        stable_sample_batches,
+        seed
+    ):
+        self.datamodule = datamodule
+        self.accelerator = accelerator
+        self.vae = vae
+        self.unet = unet
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+        self.output_dir = output_dir
+        self.sample_image_size = sample_image_size
+        self.seed = seed
+        self.random_sample_batches = random_sample_batches
+        self.sample_batch_size = sample_batch_size
+        self.stable_sample_batches = stable_sample_batches
+
+    @torch.no_grad()
+    def checkpoint(self):
+        print("Saving model...")
+
+        unwrapped = self.accelerator.unwrap_model(self.unet)
+        pipeline = StableDiffusionPipeline(
+            text_encoder=self.text_encoder,
+            vae=self.vae,
+            unet=unwrapped,
+            tokenizer=self.tokenizer,
+            scheduler=PNDMScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+            ),
+            safety_checker=NoCheck(),
+            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        )
+        pipeline.enable_attention_slicing()
+        pipeline.save_pretrained(f"{self.output_dir}/model.ckpt")
+
+        del unwrapped
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @torch.no_grad()
+    def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
+        samples_path = f"{self.output_dir}/samples/{mode}"
+        os.makedirs(samples_path, exist_ok=True)
+        checker = NoCheck()
+
+        unwrapped = self.accelerator.unwrap_model(self.unet)
+        pipeline = StableDiffusionPipeline(
+            text_encoder=self.text_encoder,
+            vae=self.vae,
+            unet=unwrapped,
+            tokenizer=self.tokenizer,
+            scheduler=LMSDiscreteScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            ),
+            safety_checker=NoCheck(),
+            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        ).to(self.accelerator.device)
+        pipeline.enable_attention_slicing()
+
+        data = {
+            "training": self.datamodule.train_dataloader(),
+            "validation": self.datamodule.val_dataloader(),
+        }[mode]
+
+        if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
+            stable_latents = torch.randn(
+                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+                device=pipeline.device,
+                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
+            )
+
+            all_samples = []
+            filename = "stable_step_%d.png" % step
+
+            # Generate and save stable samples
+            for i in range(0, self.stable_sample_batches):
+                # Re-create the enumerator for every batch so it isn't already
+                # exhausted after the first pass over the dataloader.
+                data_enum = enumerate(data)
+                prompt = [prompt for b, batch in data_enum for j, prompt in enumerate(
+                    batch["prompts"]) if b * data.batch_size + j < self.sample_batch_size]
+
+                with self.accelerator.autocast():
+                    samples = pipeline(
+                        prompt=prompt,
+                        height=self.sample_image_size,
+                        latents=stable_latents[:len(prompt)],
+                        width=self.sample_image_size,
+                        guidance_scale=guidance_scale,
+                        eta=eta,
+                        num_inference_steps=num_inference_steps,
+                        output_type='pil'
+                    )["sample"]
+
+                all_samples += samples
+                del samples
+
+            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
+            image_grid.save(f"{samples_path}/{filename}")
+
+            del all_samples
+            del image_grid
+            del stable_latents
+
+        all_samples = []
+        filename = "step_%d.png" % step
+
+        # Generate and save random samples
+        for i in range(0, self.random_sample_batches):
+            data_enum = enumerate(data)
+            prompt = [prompt for b, batch in data_enum for j, prompt in enumerate(
+                batch["prompts"]) if b * data.batch_size + j < self.sample_batch_size]
+
+            with self.accelerator.autocast():
+                samples = pipeline(
+                    prompt=prompt,
+                    height=self.sample_image_size,
+                    width=self.sample_image_size,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                )["sample"]
+
+            all_samples += samples
+            del samples
+
+        image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+        image_grid.save(f"{samples_path}/{filename}")
+
+        del all_samples
+        del image_grid
+
+        del checker
+        del unwrapped
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+def main():
+    args = parse_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = f"{args.output_dir}/{slugify(args.identifier)}/{now}"
+    os.makedirs(basepath, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    if args.with_prior_preservation:
+        class_images_dir = Path(args.class_data_dir)
+        if not class_images_dir.exists():
+            class_images_dir.mkdir(parents=True)
+        cur_class_images = len(list(class_images_dir.iterdir()))
+
+        if cur_class_images < args.num_class_images:
+            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            pipeline = StableDiffusionPipeline.from_pretrained(
+                args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
+            pipeline.set_progress_bar_config(disable=True)
+
+            num_new_images = args.num_class_images - cur_class_images
+            logger.info(f"Number of class images to sample: {num_new_images}.")
+
+            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+            sample_dataloader = accelerator.prepare(sample_dataloader)
+            pipeline.to(accelerator.device)
+
+            for example in tqdm(
+                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+            ):
+                with accelerator.autocast():
+                    images = pipeline(example["prompt"]).images
+
+                for i, image in enumerate(images):
+                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+    if args.tokenizer_name:
+        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+    elif args.pretrained_model_name_or_path:
+        tokenizer = CLIPTokenizer.from_pretrained(
+            args.pretrained_model_name_or_path + '/tokenizer'
+        )
+
+    # Load models and create wrapper for stable diffusion
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path + '/text_encoder',
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path + '/vae',
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path + '/unet',
+    )
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    # slice_size = unet.config.attention_head_dim // 2
+    # unet.set_attention_slice(slice_size)
+
+    # Freeze vae and unet
+    # freeze_params(vae.parameters())
+    # freeze_params(text_encoder.parameters())
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+
+    # Initialize the optimizer
+    optimizer = torch.optim.AdamW(
+        unet.parameters(),  # only optimize unet
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+    )
+
+    # TODO (patil-suraj): load scheduler using args
+    noise_scheduler = DDPMScheduler(
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+    )
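Note the interaction between --scale_lr (enabled by default) and the optimizer above: the rate handed to AdamW is the base rate multiplied by accumulation steps, per-device batch size, and process count. A worked example with hypothetical numbers:

    # learning_rate=1e-4, gradient_accumulation_steps=2, train_batch_size=1, 2 processes:
    # 1e-4 * 2 * 1 * 2 = 4e-4 is the rate actually passed to AdamW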
+
+    def collate_fn(examples):
+        prompts = [example["prompts"] for example in examples]
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
+
+        # concat class and instance examples for prior preservation
+        if args.with_prior_preservation:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
+
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+
+        batch = {
+            "prompts": prompts,
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
+        }
+        return batch
+
+    datamodule = CSVDataModule(
+        data_root=args.train_data_dir,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        instance_prompt=args.instance_prompt,
+        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+        class_prompt=args.class_prompt,
+        size=args.resolution,
+        identifier=args.identifier,
+        repeats=args.repeats,
+        center_crop=args.center_crop,
+        collate_fn=collate_fn)
+
+    datamodule.prepare_data()
+    datamodule.setup()
+
+    train_dataloader = datamodule.train_dataloader()
+    val_dataloader = datamodule.val_dataloader()
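For reference, with prior preservation enabled collate_fn stacks the class examples behind the instance examples in one batch; the training step below then applies a single MSE over the combined batch. Hypothetical shapes for train_batch_size=2 and resolution=512:

    # batch["pixel_values"]: torch.Size([4, 3, 512, 512]) — 2 instance images, then their 2 class images
    # batch["input_ids"]:    torch.Size([4, L]) — instance then class prompt ids, padded to the longest length L
    # batch["prompts"]:      the 2 instance caption strings only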
+
+    checkpointer = Checkpointer(
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        output_dir=basepath,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        random_sample_batches=args.random_sample_batches,
+        stable_sample_batches=args.stable_sample_batches,
+        seed=args.seed
+    )
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(
+        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+    )
+
+    unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    # Move text_encoder and vae to device
+    text_encoder.to(accelerator.device)
+    vae.to(accelerator.device)
+
+    # Keep text_encoder and vae in eval mode as we don't train these
+    # text_encoder.eval()
+    # vae.eval()
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(
+        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(
+        args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if accelerator.is_main_process:
+        accelerator.init_trackers("dreambooth", config=vars(args))
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
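The step math above counts validation batches as well as training batches, matching the per-epoch progress bar below, which ticks through both loops. A worked example with hypothetical dataloader sizes:

    # len(train_dataloader)=80, len(val_dataloader)=20, gradient_accumulation_steps=1:
    # num_update_steps_per_epoch = ceil((80 + 20) / 1) = 100
    # with max_train_steps=5000, num_train_epochs = ceil(5000 / 100) = 50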
+
+    global_step = 0
+    min_val_loss = np.inf
+
+    checkpointer.save_samples(
+        "validation",
+        0,
+        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Global steps")
+
+    local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
+    local_progress_bar.set_description("Steps")
+
+    try:
+        for epoch in range(args.num_train_epochs):
+            local_progress_bar.reset()
+
+            unet.train()
+            train_loss = 0.0
+
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(unet):
+                    with accelerator.autocast():
+                        # Convert images to latent space
+                        with torch.no_grad():
+                            latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                            latents = latents * 0.18215
+
+                        # Sample noise that we'll add to the latents
+                        noise = torch.randn(latents.shape).to(latents.device)
+                        bsz = latents.shape[0]
+                        # Sample a random timestep for each image
+                        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
+                                                  (bsz,), device=latents.device).long()
+
+                        # Add noise to the latents according to the noise magnitude at each timestep
+                        # (this is the forward diffusion process)
+                        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                        # Get the text embedding for conditioning
+                        with torch.no_grad():
+                            encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                        # Predict the noise residual
+                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+
+                        accelerator.backward(loss)
+
+                        optimizer.step()
+                        if not accelerator.optimizer_step_was_skipped:
+                            lr_scheduler.step()
+                        optimizer.zero_grad(set_to_none=True)
+
+                        loss = loss.detach().item()
+                        train_loss += loss
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    progress_bar.update(1)
+                    local_progress_bar.update(1)
+
+                    global_step += 1
+
+                    if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
+                        progress_bar.clear()
+                        local_progress_bar.clear()
+
+                        checkpointer.save_samples(
+                            "training",
+                            global_step,
+                            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+
+                logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                local_progress_bar.set_postfix(**logs)
+
+                if global_step >= args.max_train_steps:
+                    break
+
+            train_loss /= len(train_dataloader)
+
+            unet.eval()
+            val_loss = 0.0
+
+            for step, batch in enumerate(val_dataloader):
+                with torch.no_grad(), accelerator.autocast():
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    noise = torch.randn(latents.shape).to(latents.device)
+                    bsz = latents.shape[0]
+                    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
+                                              (bsz,), device=latents.device).long()
+
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+
+                    loss = loss.detach().item()
+                    val_loss += loss
+
+                if accelerator.sync_gradients:
+                    progress_bar.update(1)
+                    local_progress_bar.update(1)
+
+                logs = {"mode": "validation", "loss": loss}
+                local_progress_bar.set_postfix(**logs)
+
+            val_loss /= len(val_dataloader)
+
+            accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
+
+            progress_bar.clear()
+            local_progress_bar.clear()
+
+            if min_val_loss > val_loss:
+                accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                min_val_loss = val_loss
+
+            checkpointer.save_samples(
+                "validation",
+                global_step,
+                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+
+        accelerator.wait_for_everyone()
+
+        # Create the pipeline using the trained modules and save it.
+        if accelerator.is_main_process:
+            print("Finished! Saving final checkpoint and resume state.")
+            checkpointer.checkpoint()
+
+        accelerator.end_training()
+
+    except KeyboardInterrupt:
+        if accelerator.is_main_process:
+            print("Interrupted, saving checkpoint and resume state...")
+            checkpointer.checkpoint()
+        accelerator.end_training()
+        quit()
+
+
+if __name__ == "__main__":
+    main()
-- 
cgit v1.2.3-70-g09d2