import os
import pandas as pd
from pathlib import Path
import PIL
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

# Map the user-facing interpolation names onto torchvision resampling modes.
# (The old PIL.Image.LINEAR/BICUBIC/... constants were removed in Pillow 10;
# torchvision's InterpolationMode is the supported spelling.)
_INTERPOLATIONS = {
    "linear": transforms.InterpolationMode.BILINEAR,
    "bilinear": transforms.InterpolationMode.BILINEAR,
    "bicubic": transforms.InterpolationMode.BICUBIC,
    "lanczos": transforms.InterpolationMode.LANCZOS,
}


class CSVDataModule(pl.LightningDataModule):
    """LightningDataModule backed by a CSV of image paths and captions.

    The CSV must contain ``image`` (path relative to the CSV's directory),
    ``caption`` and ``skip`` columns; rows whose ``skip`` equals ``"x"`` are
    dropped. The remaining rows are split 80/20 into train/validation sets.

    Args:
        batch_size: Batch size for both dataloaders.
        data_file: Path to the metadata CSV; must be an existing file.
        tokenizer: Tokenizer used to produce prompt token ids.
        instance_prompt: Prompt tokenized once per example as
            ``instance_prompt_ids``.
        class_data_root: Optional directory of prior-preservation class
            images; created if missing.
        class_prompt: Prompt tokenized for class images (only used when
            ``class_data_root`` is given).
        size: Square side length images are resized/cropped to.
        repeats: How many times the training set is repeated per epoch.
        interpolation: One of ``linear``/``bilinear``/``bicubic``/``lanczos``.
        identifier: Placeholder substituted into captions via ``str.format``.
        center_crop: Use a deterministic center crop instead of a random crop.
        collate_fn: Optional collate function for both dataloaders.
    """

    def __init__(self, batch_size, data_file, tokenizer, instance_prompt,
                 class_data_root=None, class_prompt=None, size=512, repeats=100,
                 interpolation="bicubic", identifier="*", center_crop=False,
                 collate_fn=None):
        super().__init__()

        self.data_file = Path(data_file)
        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")
        self.data_root = self.data_file.parent

        self.tokenizer = tokenizer
        self.instance_prompt = instance_prompt
        self.class_data_root = class_data_root
        self.class_prompt = class_prompt
        self.size = size
        self.repeats = repeats
        self.identifier = identifier
        self.center_crop = center_crop
        self.interpolation = interpolation
        self.collate_fn = collate_fn
        self.batch_size = batch_size

    def _load_metadata(self):
        """Read the CSV and return the kept (image_path, caption) pairs."""
        metadata = pd.read_csv(self.data_file)
        image_paths = [os.path.join(self.data_root, f_path)
                       for f_path in metadata['image'].values]
        captions = metadata['caption'].values
        skips = metadata['skip'].values
        return [(img, cap)
                for img, cap, skip in zip(image_paths, captions, skips)
                if skip != "x"]

    def prepare_data(self):
        # NOTE(review): Lightning runs prepare_data on rank 0 only, so state
        # set here is not broadcast to other processes; setup() below reloads
        # the metadata if this attribute is missing.
        self.data_full = self._load_metadata()

    def setup(self, stage=None):
        # Guard against prepare_data's state not being present on this rank.
        if not hasattr(self, "data_full"):
            self.data_full = self._load_metadata()

        # 80/20 random train/validation split.
        train_set_size = int(len(self.data_full) * 0.8)
        valid_set_size = len(self.data_full) - train_set_size
        self.data_train, self.data_val = random_split(
            self.data_full, [train_set_size, valid_set_size])

        train_dataset = CSVDataset(
            self.data_train, self.tokenizer,
            instance_prompt=self.instance_prompt,
            class_data_root=self.class_data_root,
            class_prompt=self.class_prompt,
            size=self.size, interpolation=self.interpolation,
            identifier=self.identifier, center_crop=self.center_crop,
            repeats=self.repeats)
        # Validation set is not repeated (repeats defaults to 1).
        val_dataset = CSVDataset(
            self.data_val, self.tokenizer,
            instance_prompt=self.instance_prompt,
            class_data_root=self.class_data_root,
            class_prompt=self.class_prompt,
            size=self.size, interpolation=self.interpolation,
            identifier=self.identifier, center_crop=self.center_crop)

        self.train_dataloader_ = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True,
            collate_fn=self.collate_fn)
        self.val_dataloader_ = DataLoader(
            val_dataset, batch_size=self.batch_size,
            collate_fn=self.collate_fn)

    def train_dataloader(self):
        return self.train_dataloader_

    def val_dataloader(self):
        return self.val_dataloader_


class CSVDataset(Dataset):
    """Dataset of (prompt, instance image[, class image]) training examples.

    Each item is a dict with keys ``prompts``, ``instance_images`` (a
    normalized tensor), ``instance_prompt_ids`` and — when ``class_data_root``
    is given — ``class_images`` and ``class_prompt_ids``.

    Decoded PIL images are cached per instance path; class images cycle
    through ``class_data_root`` by index.
    """

    def __init__(self, data, tokenizer, instance_prompt,
                 class_data_root=None, class_prompt=None, size=512, repeats=1,
                 interpolation="bicubic", identifier="*", center_crop=False):
        self.data = data
        self.tokenizer = tokenizer
        self.instance_prompt = instance_prompt
        self.identifier = identifier

        self.num_instance_images = len(self.data)
        self._length = self.num_instance_images * repeats

        # Constant prompts only need to be tokenized once, not per item.
        self.instance_prompt_ids = self._tokenize(instance_prompt)

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images)
            # NOTE: with class images the length ignores `repeats` and covers
            # whichever collection is larger (matches the original behavior).
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
            self.class_prompt_ids = self._tokenize(class_prompt)
        else:
            self.class_data_root = None

        # Previously the `interpolation` argument was parsed but Resize was
        # hardcoded to BILINEAR; honor the requested mode.
        self.interpolation = _INTERPOLATIONS[interpolation]
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop
                else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.cache = {}

    def _tokenize(self, prompt):
        """Tokenize a prompt without padding and return its input ids."""
        return self.tokenizer(
            prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

    def __len__(self):
        return self._length

    @staticmethod
    def _open_rgb(path):
        """Open an image and force it into RGB mode."""
        image = Image.open(path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def get_example(self, i):
        """Return the untransformed example dict for index ``i``.

        Only the instance-image part is cached (keyed by path); the class
        image is looked up per index so that ``i % num_class_images`` really
        cycles through the class set instead of being frozen by the cache.
        """
        image_path, text = self.data[i % self.num_instance_images]

        cached = self.cache.get(image_path)
        if cached is None:
            cached = {
                "prompts": text.format(self.identifier),
                "instance_images": self._open_rgb(image_path),
                "instance_prompt_ids": self.instance_prompt_ids,
            }
            self.cache[image_path] = cached
        # Copy so per-index class entries never leak into the cache.
        example = dict(cached)

        if self.class_data_root:
            class_path = self.class_images[i % self.num_class_images]
            example["class_images"] = self._open_rgb(class_path)
            example["class_prompt_ids"] = self.class_prompt_ids
        return example

    def __getitem__(self, i):
        unprocessed = self.get_example(i)
        example = {
            "prompts": unprocessed["prompts"],
            "instance_images": self.image_transforms(unprocessed["instance_images"]),
            "instance_prompt_ids": unprocessed["instance_prompt_ids"],
        }
        if self.class_data_root:
            example["class_images"] = self.image_transforms(unprocessed["class_images"])
            example["class_prompt_ids"] = unprocessed["class_prompt_ids"]
        return example