| author | Volpeon <git@volpeon.ink> | 2022-09-27 10:03:12 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-09-27 10:03:12 +0200 |
| commit | 5d2abb1749b5d2f2f22ad603b5c2bf9182864520 (patch) | |
| tree | d122d75322dff5cce3f2eb6cac0efe375320b9fd /data/textual_inversion | |
| parent | Use diffusers fork with Flash Attention (diff) | |
More cleanup
Diffstat (limited to 'data/textual_inversion')
| -rw-r--r-- | data/textual_inversion/csv.py | 134 |
1 file changed, 134 insertions, 0 deletions
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
new file mode 100644
index 0000000..38ffb6f
--- /dev/null
+++ b/data/textual_inversion/csv.py
@@ -0,0 +1,134 @@
+import os
+import numpy as np
+import pandas as pd
+import random
+import PIL
+import pytorch_lightning as pl
+from PIL import Image
+import torch
+from torch.utils.data import Dataset, DataLoader, random_split
+from torchvision import transforms
+
+
+class CSVDataModule(pl.LightningDataModule):
+    def __init__(self,
+                 batch_size,
+                 data_root,
+                 tokenizer,
+                 size=512,
+                 repeats=100,
+                 interpolation="bicubic",
+                 placeholder_token="*",
+                 flip_p=0.5,
+                 center_crop=False):
+        super().__init__()
+
+        self.data_root = data_root
+        self.tokenizer = tokenizer
+        self.size = size
+        self.repeats = repeats
+        self.placeholder_token = placeholder_token
+        self.center_crop = center_crop
+        self.flip_p = flip_p
+        self.interpolation = interpolation
+
+        self.batch_size = batch_size
+
+    def prepare_data(self):
+        metadata = pd.read_csv(f'{self.data_root}/list.csv')
+        image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
+        captions = [caption for caption in metadata['caption'].values]
+        skips = [skip for skip in metadata['skip'].values]
+        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+
+    def setup(self, stage=None):
+        train_set_size = int(len(self.data_full) * 0.8)
+        valid_set_size = len(self.data_full) - train_set_size
+        self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
+
+        train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
+                                   flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+        val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
+                                 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
+
+    def train_dataloader(self):
+        return self.train_dataloader_
+
+    def val_dataloader(self):
+        return self.val_dataloader_
+
+
+class CSVDataset(Dataset):
+    def __init__(self,
+                 data,
+                 tokenizer,
+                 size=512,
+                 repeats=1,
+                 interpolation="bicubic",
+                 flip_p=0.5,
+                 placeholder_token="*",
+                 center_crop=False,
+                 ):
+
+        self.data = data
+        self.tokenizer = tokenizer
+
+        self.num_images = len(self.data)
+        self._length = self.num_images * repeats
+
+        self.placeholder_token = placeholder_token
+
+        self.interpolation = {"linear": PIL.Image.LINEAR,
+                              "bilinear": PIL.Image.BILINEAR,
+                              "bicubic": PIL.Image.BICUBIC,
+                              "lanczos": PIL.Image.LANCZOS,
+                              }[interpolation]
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+        self.cache = {}
+
+    def __len__(self):
+        return self._length
+
+    def get_example(self, i, flipped):
+        image_path, text = self.data[i % self.num_images]
+
+        if image_path in self.cache:
+            return self.cache[image_path]
+
+        example = {}
+        image = Image.open(image_path)
+        if not image.mode == "RGB":
+            image = image.convert("RGB")
+        image = self.image_transforms(image)
+
+        text = text.format(self.placeholder_token)
+
+        example["prompt"] = text
+        example["input_ids"] = self.tokenizer(
+            text,
+            padding="max_length",
+            truncation=True,
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        ).input_ids[0]
+
+        example["key"] = "-".join([image_path, "-", str(flipped)])
+        example["pixel_values"] = image
+
+        self.cache[image_path] = example
+        return example
+
+    def __getitem__(self, i):
+        flipped = random.choice([False, True])
+        example = self.get_example(i, flipped)
+        return example
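For reference, a minimal usage sketch of the CSVDataModule added in this commit. It is not part of the diff: the tokenizer checkpoint, batch size, data_root path, and placeholder token are illustrative assumptions, and it assumes the repository root is on PYTHONPATH so the new file is importable as data.textual_inversion.csv.

# Hypothetical usage sketch (not part of this commit); names and values are assumptions.
from transformers import CLIPTokenizer

from data.textual_inversion.csv import CSVDataModule

# Any tokenizer exposing model_max_length works; this CLIP checkpoint is only an example.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

datamodule = CSVDataModule(
    batch_size=4,                    # illustrative value
    data_root="data/my-concept",     # assumed folder containing list.csv with image, caption, skip columns
    tokenizer=tokenizer,
    placeholder_token="<my-token>",  # substituted into captions via str.format
)

datamodule.prepare_data()   # reads list.csv and drops rows whose skip column is "x"
datamodule.setup()          # 80/20 train/validation split, builds both dataloaders

batch = next(iter(datamodule.train_dataloader()))
print(batch["pixel_values"].shape)  # e.g. (4, 3, 512, 512)
print(batch["input_ids"].shape)     # e.g. (4, tokenizer.model_max_length)

When driven by a Lightning Trainer instead of called by hand, prepare_data and setup are invoked automatically in that order, so list.csv is parsed before the train/validation split is made.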
