From 73fe0a75cd08244f91d1baea7b63b42f9e4be08c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Sep 2022 12:39:43 +0200 Subject: Added Dreambooth training script --- data/dreambooth/csv.py | 177 ++++++++++++++++++++++++++++++++++++++++++++++ data/dreambooth/prompt.py | 16 +++++ 2 files changed, 193 insertions(+) create mode 100644 data/dreambooth/csv.py create mode 100644 data/dreambooth/prompt.py (limited to 'data/dreambooth') diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py new file mode 100644 index 0000000..04df4c6 --- /dev/null +++ b/data/dreambooth/csv.py @@ -0,0 +1,177 @@ +import os +import pandas as pd +from pathlib import Path +import PIL +import pytorch_lightning as pl +from PIL import Image +from torch.utils.data import Dataset, DataLoader, random_split +from torchvision import transforms + + +class CSVDataModule(pl.LightningDataModule): + def __init__(self, + batch_size, + data_root, + tokenizer, + instance_prompt, + class_data_root=None, + class_prompt=None, + size=512, + repeats=100, + interpolation="bicubic", + identifier="*", + center_crop=False, + collate_fn=None): + super().__init__() + + self.data_root = data_root + self.tokenizer = tokenizer + self.instance_prompt = instance_prompt + self.class_data_root = class_data_root + self.class_prompt = class_prompt + self.size = size + self.repeats = repeats + self.identifier = identifier + self.center_crop = center_crop + self.interpolation = interpolation + self.collate_fn = collate_fn + self.batch_size = batch_size + + def prepare_data(self): + metadata = pd.read_csv(f'{self.data_root}/list.csv') + image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] + captions = [caption for caption in metadata['caption'].values] + skips = [skip for skip in metadata['skip'].values] + self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] + + def setup(self, stage=None): + train_set_size = int(len(self.data_full) * 0.8) + valid_set_size = len(self.data_full) - train_set_size + self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) + + train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, + class_data_root=self.class_data_root, + class_prompt=self.class_prompt, size=self.size, repeats=self.repeats, + interpolation=self.interpolation, identifier=self.identifier, + center_crop=self.center_crop) + val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, + class_data_root=self.class_data_root, + class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation, + identifier=self.identifier, center_crop=self.center_crop) + self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, + shuffle=True, collate_fn=self.collate_fn) + self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) + + def train_dataloader(self): + return self.train_dataloader_ + + def val_dataloader(self): + return self.val_dataloader_ + + +class CSVDataset(Dataset): + def __init__(self, + data, + tokenizer, + instance_prompt, + class_data_root=None, + class_prompt=None, + size=512, + repeats=1, + interpolation="bicubic", + identifier="*", + center_crop=False, + ): + + self.data = data + self.tokenizer = tokenizer + self.instance_prompt = instance_prompt + + self.num_instance_images = len(self.data) + self._length = self.num_instance_images * repeats + + self.identifier = identifier + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + + self.class_images = list(Path(class_data_root).iterdir()) + self.num_class_images = len(self.class_images) + self._length = max(self.num_class_images, self.num_instance_images) + + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.cache = {} + + def __len__(self): + return self._length + + def get_example(self, i): + image_path, text = self.data[i % self.num_instance_images] + + if image_path in self.cache: + return self.cache[image_path] + + example = {} + + instance_image = Image.open(image_path) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + + text = text.format(self.identifier) + + example["prompts"] = text + example["instance_images"] = instance_image + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images[i % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + + example["class_images"] = class_image + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + self.cache[image_path] = example + return example + + def __getitem__(self, i): + example = {} + unprocessed_example = self.get_example(i) + + example["prompts"] = unprocessed_example["prompts"] + example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) + example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] + + if self.class_data_root: + example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) + example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] + + return example diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py new file mode 100644 index 0000000..34f510d --- /dev/null +++ b/data/dreambooth/prompt.py @@ -0,0 +1,16 @@ +from torch.utils.data import Dataset + + +class PromptDataset(Dataset): + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example -- cgit v1.2.3-70-g09d2