Diffstat (limited to 'data/textual_inversion')
-rw-r--r--  data/textual_inversion/csv.py  134
1 file changed, 134 insertions, 0 deletions
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
new file mode 100644
index 0000000..38ffb6f
--- /dev/null
+++ b/data/textual_inversion/csv.py
@@ -0,0 +1,134 @@
import os
import random

import pandas as pd
import pytorch_lightning as pl
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms


class CSVDataModule(pl.LightningDataModule):
    def __init__(self,
                 batch_size,
                 data_root,
                 tokenizer,
                 size=512,
                 repeats=100,
                 interpolation="bicubic",
                 placeholder_token="*",
                 flip_p=0.5,
                 center_crop=False):
        super().__init__()

        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.repeats = repeats
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p
        self.interpolation = interpolation

        self.batch_size = batch_size

    def prepare_data(self):
        # Expects <data_root>/list.csv with the columns "image", "caption"
        # and "skip"; rows whose skip column is "x" are excluded.
        metadata = pd.read_csv(f'{self.data_root}/list.csv')
        image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
        captions = metadata['caption'].values
        skips = metadata['skip'].values
        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]

    def setup(self, stage=None):
        # 80/20 random train/validation split of the non-skipped rows.
        train_set_size = int(len(self.data_full) * 0.8)
        valid_set_size = len(self.data_full) - train_set_size
        self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])

        train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats,
                                   interpolation=self.interpolation, flip_p=self.flip_p,
                                   placeholder_token=self.placeholder_token, center_crop=self.center_crop)
        val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size,
                                 interpolation=self.interpolation, flip_p=self.flip_p,
                                 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self.train_dataloader_

    def val_dataloader(self):
        return self.val_dataloader_


class CSVDataset(Dataset):
    def __init__(self,
                 data,
                 tokenizer,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 flip_p=0.5,
                 placeholder_token="*",
                 center_crop=False,
                 ):

        self.data = data
        self.tokenizer = tokenizer
        self.flip_p = flip_p

        self.num_images = len(self.data)
        self._length = self.num_images * repeats

        self.placeholder_token = placeholder_token

        # Map the requested interpolation mode to torchvision's enum and pass
        # it to Resize below so the argument actually takes effect ("linear"
        # is treated as an alias for bilinear).
        self.interpolation = {"linear": transforms.InterpolationMode.BILINEAR,
                              "bilinear": transforms.InterpolationMode.BILINEAR,
                              "bicubic": transforms.InterpolationMode.BICUBIC,
                              "lanczos": transforms.InterpolationMode.LANCZOS,
                              }[interpolation]
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=self.interpolation),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Cache of fully processed examples, keyed by (path, flipped); note
        # that this freezes one random crop per image/flip combination.
        self.cache = {}

    def __len__(self):
        return self._length

    def get_example(self, i, flipped):
        image_path, text = self.data[i % self.num_images]
        key = f"{image_path}-{flipped}"

        if key in self.cache:
            return self.cache[key]

        example = {}
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        if flipped:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        image = self.image_transforms(image)

        # Substitute the placeholder token into the caption template.
        text = text.format(self.placeholder_token)

        example["prompt"] = text
        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        example["key"] = key
        example["pixel_values"] = image

        self.cache[key] = example
        return example

    def __getitem__(self, i):
        # Draw the horizontal-flip variant with probability flip_p.
        flipped = random.random() < self.flip_p
        return self.get_example(i, flipped)
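For reference, list.csv is expected to look something like the following (hypothetical contents, inferred from the columns prepare_data reads; the {} in each caption is replaced by placeholder_token, and an x in the skip column drops the row):

    image,caption,skip
    img/0001.jpg,a photo of {},
    img/0002.jpg,a rendering of {},x
    img/0003.jpg,a photo of {},

And a minimal usage sketch, assuming a Hugging Face CLIPTokenizer; the model name and data_root below are illustrative, not part of this commit:

    # Hypothetical usage, not part of the diff: wire the data module to a
    # CLIP tokenizer and pull one training batch.
    from transformers import CLIPTokenizer
    from data.textual_inversion.csv import CSVDataModule

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    datamodule = CSVDataModule(batch_size=4, data_root="data/my_concept",
                               tokenizer=tokenizer, placeholder_token="*")
    datamodule.prepare_data()   # reads data/my_concept/list.csv
    datamodule.setup()

    batch = next(iter(datamodule.train_dataloader()))
    print(batch["pixel_values"].shape)  # torch.Size([4, 3, 512, 512])
    print(batch["input_ids"].shape)     # torch.Size([4, 77]) for CLIP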