Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 145 |
1 file changed, 0 insertions, 145 deletions
diff --git a/data.py b/data.py
deleted file mode 100644
index 0d1e96e..0000000
--- a/data.py
+++ /dev/null
@@ -1,145 +0,0 @@
import os
import numpy as np
import pandas as pd
import random
import PIL
import pytorch_lightning as pl
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split

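# Data module that reads (image path, caption) pairs from <data_root>/list.csv
# and serves train/val dataloaders (e.g. for textual-inversion style fine-tuning).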
class CSVDataModule(pl.LightningDataModule):
    def __init__(self,
                 batch_size,
                 data_root,
                 tokenizer,
                 size=512,
                 repeats=100,
                 interpolation="bicubic",
                 placeholder_token="*",
                 flip_p=0.5,
                 center_crop=False):
        super().__init__()

        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.repeats = repeats
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p
        self.interpolation = interpolation

        self.batch_size = batch_size

    def prepare_data(self):
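        # list.csv carries image, caption, and skip columns; rows whose skip
        # column is "x" are dropped. Lightning calls prepare_data on a single
        # process only, so storing state here assumes single-process training.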
        metadata = pd.read_csv(f'{self.data_root}/list.csv')
        image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
        captions = metadata['caption'].values
        skips = metadata['skip'].values
        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]

    def setup(self, stage=None):
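        # 80/20 random train/val split; repeats inflates only the training set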
        train_set_size = int(len(self.data_full) * 0.8)
        valid_set_size = len(self.data_full) - train_set_size
        self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])

        train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
                                   flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
        val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
                                 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self.train_dataloader_

    def val_dataloader(self):
        return self.val_dataloader_

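# Dataset of (image, caption) pairs: each caption is tokenized to fixed-length
# input_ids and each image becomes a float32 tensor scaled to [-1, 1].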
class CSVDataset(Dataset):
    def __init__(self,
                 data,
                 tokenizer,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 flip_p=0.5,
                 placeholder_token="*",
                 center_crop=False,
                 ):

        self.data = data
        self.tokenizer = tokenizer

        self.num_images = len(self.data)
        self._length = self.num_images * repeats

        self.placeholder_token = placeholder_token

        self.size = size
        self.center_crop = center_crop
        # "linear" is served by BILINEAR; recent Pillow releases dropped the LINEAR alias
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip_p = flip_p

        self.cache = {}

    def __len__(self):
        return self._length

    def get_example(self, i, flipped):
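        # each (image path, flip) variant is decoded and preprocessed at most
        # once, then served from the in-memory cache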
        image_path, text = self.data[i % self.num_images]

        cache_key = (image_path, flipped)
        if cache_key in self.cache:
            return self.cache[cache_key]

        example = {}
        image = Image.open(image_path)

        if image.mode != "RGB":
            image = image.convert("RGB")

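        # substitute the placeholder token into the caption template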
        text = text.format(self.placeholder_token)

        example["prompt"] = text
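        # pad/truncate to the tokenizer's fixed context length so batches collate cleanly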
        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

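        # optionally crop the largest centered square before resizing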
        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size),
                             resample=self.interpolation)
        if flipped:
            # deterministic horizontal flip keeps the cache key truthful
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)  # uint8 [0, 255] -> float32 [-1, 1]

        example["key"] = f"{image_path}-{flipped}"
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW

        self.cache[cache_key] = example
        return example

    def __getitem__(self, i):
        # decide the flip here, with probability flip_p, so get_example can
        # cache each variant deterministically
        flipped = random.random() < self.flip_p
        example = self.get_example(i, flipped)
        return example
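
For reference, a minimal driver for the deleted module would have looked something like this (a sketch only: the CLIP tokenizer and the data_root path are assumptions, not part of this file):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
dm = CSVDataModule(
    batch_size=4,
    data_root="data/concept",  # hypothetical path; must contain list.csv
    tokenizer=tokenizer,
    size=512,
    placeholder_token="*",
)
dm.prepare_data()
dm.setup()
batch = next(iter(dm.train_dataloader()))
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 512, 512])
print(batch["input_ids"].shape)     # torch.Size([4, 77]) with CLIP's tokenizer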