import math
import os
from pathlib import Path

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms


class CSVDataModule(pl.LightningDataModule):
    """Lightning data module built from a CSV manifest of images and prompts.

    The CSV is read with ``pandas.read_csv`` and must contain an ``image``
    column (paths are joined onto the CSV's parent directory) and a ``prompt``
    column. Optional columns:

    * ``nprompt`` -- per-row negative prompt; a missing column or empty cells
      become ``""``.
    * ``skip`` -- rows whose value is exactly ``"x"`` are dropped.

    The surviving rows are split at random into train / validation subsets,
    each wrapped in a :class:`CSVDataset`. Both splits use the same
    ``repeats`` and the same (random) image transforms.
    """

    def __init__(self, batch_size, data_file, tokenizer, size=512, repeats=100,
                 interpolation="bicubic", placeholder_token="*", center_crop=False,
                 valid_set_size=None, generator=None):
        """
        Args:
            batch_size: Batch size for both the train and validation loaders.
            data_file: Path to the CSV manifest.
            tokenizer: Callable used by ``CSVDataset`` to encode prompts
                (presumably a Hugging Face tokenizer -- it is called with
                ``padding``/``truncation``/``max_length``/``return_tensors``
                keywords and must expose ``model_max_length``).
            size: Target side length passed to resize and crop.
            repeats: How many times each row is repeated in a dataset epoch.
            interpolation: Key into ``_INTERPOLATION_MODES``.
            placeholder_token: Passed to ``str.format`` on every prompt, so it
                fills ``{}`` replacement fields.
            center_crop: Center-crop if True, otherwise random-crop.
            valid_set_size: Optional upper bound on the validation split size.
            generator: Optional generator forwarded to ``random_split``.

        Raises:
            ValueError: If ``data_file`` is not an existing file.
        """
        super().__init__()
        self.data_file = Path(data_file)
        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")
        self.data_root = self.data_file.parent
        self.tokenizer = tokenizer
        self.size = size
        self.repeats = repeats
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.generator = generator
        self.batch_size = batch_size
        # Filled by prepare_data(), or lazily by setup() when prepare_data()
        # did not run in this process.
        self.data_full = None

    def _read_metadata(self):
        """Return ``(image_path, prompt, nprompt)`` tuples for non-skipped rows."""
        metadata = pd.read_csv(self.data_file)
        n_rows = len(metadata)
        image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata["image"].values]
        prompts = metadata["prompt"].values
        # Empty CSV cells are read as float NaN; keep negative prompts as str so
        # a batch never mixes str and float (which breaks default collation).
        if "nprompt" in metadata:
            nprompts = metadata["nprompt"].fillna("").values
        else:
            nprompts = [""] * n_rows
        skips = metadata["skip"].values if "skip" in metadata else [""] * n_rows
        return [
            (image_path, prompt, nprompt)
            for image_path, prompt, nprompt, skip in zip(image_paths, prompts, nprompts, skips)
            if skip != "x"
        ]

    def prepare_data(self):
        """Read the CSV manifest into ``self.data_full``."""
        self.data_full = self._read_metadata()

    def _make_dataset(self, subset):
        """Wrap one split in a ``CSVDataset`` configured from this module."""
        return CSVDataset(subset, self.tokenizer, size=self.size, repeats=self.repeats,
                          interpolation=self.interpolation,
                          placeholder_token=self.placeholder_token,
                          center_crop=self.center_crop)

    def setup(self, stage=None):
        """Split the manifest and build the train / validation data loaders.

        The validation split is 20% of the rows, capped at ``valid_set_size``
        when that is truthy, and never smaller than one row.

        Raises:
            ValueError: If fewer than two usable rows exist, so non-empty train
                and validation splits cannot both be formed.
        """
        if self.data_full is None:
            # Lightning runs prepare_data() on a single process (per node by
            # default), so other DDP ranks must load the manifest here.
            self.data_full = self._read_metadata()

        n_total = len(self.data_full)
        if n_total < 2:
            raise ValueError(
                f"{self.data_file} has {n_total} usable row(s); at least 2 are "
                "needed to form non-empty train and validation splits"
            )
        valid_set_size = int(n_total * 0.2)
        if self.valid_set_size:
            valid_set_size = min(valid_set_size, self.valid_set_size)
        valid_set_size = max(valid_set_size, 1)
        train_set_size = n_total - valid_set_size

        self.data_train, self.data_val = random_split(
            self.data_full, [train_set_size, valid_set_size], self.generator
        )
        self.train_dataloader_ = DataLoader(
            self._make_dataset(self.data_train),
            batch_size=self.batch_size,
            pin_memory=True,
            shuffle=True,
        )
        self.val_dataloader_ = DataLoader(
            self._make_dataset(self.data_val),
            batch_size=self.batch_size,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the loader built by ``setup``."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the loader built by ``setup``."""
        return self.val_dataloader_


# Maps CSVDataset's ``interpolation`` argument to a torchvision mode.
# "linear" follows PIL, where ``Image.LINEAR`` is an alias of BILINEAR.
_INTERPOLATION_MODES = {
    "nearest": transforms.InterpolationMode.NEAREST,
    "linear": transforms.InterpolationMode.BILINEAR,
    "bilinear": transforms.InterpolationMode.BILINEAR,
    "bicubic": transforms.InterpolationMode.BICUBIC,
    "lanczos": transforms.InterpolationMode.LANCZOS,
}


class CSVDataset(Dataset):
    """Dataset of ``(image_path, prompt, nprompt)`` rows with image transforms.

    Each row is loaded once (image converted to RGB, prompt formatted and
    tokenized) and kept in an in-memory cache for the dataset's lifetime; the
    resize / crop / flip / normalize transforms are re-applied on every access.
    """

    def __init__(self,
                 data,
                 tokenizer,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 placeholder_token="*",
                 center_crop=False,
                 batch_size=1,
                 ):
        """
        Args:
            data: Indexable sequence of ``(image_path, prompt, nprompt)`` tuples.
            tokenizer: Callable used to encode prompts (see ``get_example``).
            size: Target side length for resize and crop.
            repeats: Virtual repetitions of ``data`` per epoch.
            interpolation: Key into ``_INTERPOLATION_MODES``; unknown keys raise
                ``KeyError``.
            placeholder_token: Passed to ``str.format`` on each prompt.
            center_crop: Center-crop if True, otherwise random-crop.
            batch_size: ``__len__`` is rounded up to a multiple of this.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.placeholder_token = placeholder_token
        self.batch_size = batch_size
        # Row index -> unprocessed example dict; never evicted.
        self.cache = {}

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        self.interpolation = _INTERPOLATION_MODES[interpolation]

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        # Padded up to a whole number of batches; get_example wraps indices
        # with a modulo, so the extra indices reuse earlier rows.
        return math.ceil(self._length / self.batch_size) * self.batch_size

    def get_example(self, i):
        """Return the cached, untransformed example for virtual index ``i``.

        The dict holds ``prompts`` (formatted prompt), ``nprompts``,
        ``pixel_values`` (an RGB PIL image) and ``input_ids`` (first row of the
        tokenizer's ``input_ids``).
        """
        index = i % self.num_instance_images
        # Key by row, not by image path: the same image may appear in several
        # rows with different prompts.
        cached = self.cache.get(index)
        if cached is not None:
            return cached

        image_path, prompt, nprompt = self.data[index]

        # convert() always returns a fully loaded copy, so the underlying file
        # is closed when the with-block exits.
        with Image.open(image_path) as image:
            instance_image = image.convert("RGB")

        prompt = prompt.format(self.placeholder_token)

        example = {
            "prompts": prompt,
            "nprompts": nprompt,
            "pixel_values": instance_image,
            "input_ids": self.tokenizer(
                prompt,
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids[0],
        }

        self.cache[index] = example
        return example

    def __getitem__(self, i):
        """Return the example for ``i`` with image transforms freshly applied."""
        unprocessed_example = self.get_example(i)
        return {
            "prompts": unprocessed_example["prompts"],
            "nprompts": unprocessed_example["nprompts"],
            "input_ids": unprocessed_example["input_ids"],
            "pixel_values": self.image_transforms(unprocessed_example["pixel_values"]),
        }