Diffstat (limited to 'data.py')
-rw-r--r--  data.py  145
1 file changed, 145 insertions, 0 deletions
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..0d1e96e
--- /dev/null
+++ b/data.py
@@ -0,0 +1,145 @@
import os
import numpy as np
import pandas as pd
import random
import PIL
import pytorch_lightning as pl
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms


class CSVDataModule(pl.LightningDataModule):
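    """Lightning data module that reads (image, caption) pairs from
    `<data_root>/list.csv`, filters out rows marked to skip, splits the
    remainder 80/20, and serves them through CSVDataset instances."""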
    def __init__(self,
                 batch_size,
                 data_root,
                 tokenizer,
                 size=512,
                 repeats=100,
                 interpolation="bicubic",
                 placeholder_token="*",
                 flip_p=0.5,
                 center_crop=False):
        super().__init__()

        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.repeats = repeats
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p
        self.interpolation = interpolation

        self.batch_size = batch_size

    def prepare_data(self):
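        # list.csv must provide 'image', 'caption', and 'skip' columns;
        # rows whose 'skip' value is "x" are excluded.
        # NOTE: Lightning calls prepare_data() on a single process, so
        # setting state here is only safe for single-process training;
        # stateful loading normally belongs in setup().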
        metadata = pd.read_csv(f'{self.data_root}/list.csv')
        image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
        captions = metadata['caption'].values
        skips = metadata['skip'].values
        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]

    def setup(self, stage=None):
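        # Random 80/20 train/validation split over the filtered pairs.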
        train_set_size = int(len(self.data_full) * 0.8)
        valid_set_size = len(self.data_full) - train_set_size
        self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])

        train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats,
                                   interpolation=self.interpolation, flip_p=self.flip_p,
                                   placeholder_token=self.placeholder_token, center_crop=self.center_crop)
        val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size,
                                 interpolation=self.interpolation, flip_p=self.flip_p,
                                 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self.train_dataloader_

    def val_dataloader(self):
        return self.val_dataloader_


class CSVDataset(Dataset):
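    """Map-style dataset that tokenizes each caption (with the placeholder
    token substituted in) and preprocesses each image to a (3, size, size)
    float32 tensor in [-1, 1]. `repeats` multiplies the epoch length so a
    small image set can be iterated many times per epoch."""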
    def __init__(self,
                 data,
                 tokenizer,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 flip_p=0.5,
                 placeholder_token="*",
                 center_crop=False,
                 ):

        self.data = data
        self.tokenizer = tokenizer

        self.num_images = len(self.data)
        self._length = self.num_images * repeats

        self.placeholder_token = placeholder_token

        self.size = size
        self.center_crop = center_crop
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip_p = flip_p

        self.cache = {}

    def __len__(self):
        return self._length

    def get_example(self, i, flipped):
        image_path, text = self.data[i % self.num_images]

        # Cache per (path, flip) so both augmented variants are kept;
        # caching on the path alone would freeze the first flip forever.
        cache_key = (image_path, flipped)
        if cache_key in self.cache:
            return self.cache[cache_key]

        example = {}
        image = Image.open(image_path)

        if image.mode != "RGB":
            image = image.convert("RGB")

        text = text.format(self.placeholder_token)

        example["prompt"] = text
        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size),
                             resample=self.interpolation)
        if flipped:
            image = transforms.functional.hflip(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)  # scale [0, 255] -> [-1, 1]

        example["key"] = "-".join([image_path, str(flipped)])
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)

        self.cache[cache_key] = example
        return example

    def __getitem__(self, i):
        # Draw the flip with probability flip_p; the flag is passed down so
        # the cached example stays consistent with it.
        flipped = random.random() < self.flip_p
        return self.get_example(i, flipped)
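
For reference, a minimal usage sketch (not part of this commit): it assumes a Hugging Face CLIPTokenizer and a hypothetical data_root; any tokenizer that exposes model_max_length and returns input_ids will do.

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
datamodule = CSVDataModule(
    batch_size=4,
    data_root="data/my_concept",  # hypothetical; must hold list.csv and the images
    tokenizer=tokenizer,
)
datamodule.prepare_data()
datamodule.setup()
batch = next(iter(datamodule.train_dataloader()))
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 512, 512])
print(batch["input_ids"].shape)     # torch.Size([4, 77]) for CLIP's tokenizer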