summary refs log tree commit diff stats
path: root/data/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'data/textual_inversion')
-rw-r--r-- data/textual_inversion/csv.py 134
1 files changed, 134 insertions, 0 deletions
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
new file mode 100644
index 0000000..38ffb6f
--- /dev/null
+++ b/data/textual_inversion/csv.py
@@ -0,0 +1,134 @@
1import os
2import numpy as np
3import pandas as pd
4import random
5import PIL
6import pytorch_lightning as pl
7from PIL import Image
8import torch
9from torch.utils.data import Dataset, DataLoader, random_split
10from torchvision import transforms
11
12
class CSVDataModule(pl.LightningDataModule):
    """Lightning data module serving image/caption pairs listed in a CSV file.

    ``<data_root>/list.csv`` is read with pandas and must provide the columns
    ``image`` (path joined onto ``data_root``), ``caption`` and ``skip``.
    Rows whose ``skip`` value equals ``"x"`` are left out.  The remaining
    rows are split 80/20 at random into a training and a validation subset,
    each wrapped in a ``CSVDataset`` and a ``DataLoader``.
    """

    def __init__(self,
                 batch_size,
                 data_root,
                 tokenizer,
                 size=512,
                 repeats=100,
                 interpolation="bicubic",
                 placeholder_token="*",
                 flip_p=0.5,
                 center_crop=False):
        """Store the configuration; no file is read here.

        Args:
            batch_size: Batch size used by both data loaders.
            data_root: Directory containing ``list.csv`` and the images.
            tokenizer: Tokenizer forwarded to ``CSVDataset``.
            size: Forwarded to ``CSVDataset`` (train and validation).
            repeats: Forwarded to the *training* ``CSVDataset`` only; the
                validation dataset keeps ``CSVDataset``'s own default.
            interpolation: Forwarded to ``CSVDataset``.
            placeholder_token: Forwarded to ``CSVDataset``.
            flip_p: Forwarded to ``CSVDataset``.
            center_crop: Forwarded to ``CSVDataset``.
        """
        super().__init__()

        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.repeats = repeats
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p
        self.interpolation = interpolation

        self.batch_size = batch_size

        # Filled by prepare_data() or, if that did not run in this process,
        # lazily by setup().
        self.data_full = None

    def _read_entries(self):
        """Return the non-skipped ``(image_path, caption)`` rows of ``list.csv``."""
        metadata = pd.read_csv(os.path.join(self.data_root, "list.csv"))
        image_paths = [os.path.join(self.data_root, name) for name in metadata['image'].values]
        rows = zip(image_paths, metadata['caption'].values, metadata['skip'].values)
        return [(path, caption) for path, caption, skip in rows if skip != "x"]

    def prepare_data(self):
        """Read ``list.csv`` and store the usable rows in ``self.data_full``.

        Kept for callers that invoke it directly.  ``setup`` does not depend
        on it: Lightning documents that state assigned in ``prepare_data``
        is not available on every process.
        """
        self.data_full = self._read_entries()

    def setup(self, stage=None):
        """Split the data 80/20 and build the train/validation data loaders.

        Args:
            stage: Accepted for the Lightning hook signature; unused.
        """
        # Load here when prepare_data() did not run in this process, so the
        # module also works on ranks where Lightning skips prepare_data().
        if getattr(self, "data_full", None) is None:
            self.data_full = self._read_entries()

        train_set_size = int(len(self.data_full) * 0.8)
        valid_set_size = len(self.data_full) - train_set_size
        self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])

        # Options shared by both datasets; only the training set is repeated.
        shared = dict(size=self.size, interpolation=self.interpolation, flip_p=self.flip_p,
                      placeholder_token=self.placeholder_token, center_crop=self.center_crop)
        train_dataset = CSVDataset(self.data_train, self.tokenizer, repeats=self.repeats, **shared)
        val_dataset = CSVDataset(self.data_val, self.tokenizer, **shared)

        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        """Return the training loader created by ``setup``."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the validation loader created by ``setup``."""
        return self.val_dataloader_
61
62
class CSVDataset(Dataset):
    """Dataset turning ``(image_path, caption)`` pairs into training examples.

    Each item is a dict with the keys ``prompt`` (caption after
    ``str.format`` with the placeholder token), ``input_ids`` (tokenized
    prompt), ``key`` (string built from the image path and the flip flag)
    and ``pixel_values`` (transformed image tensor).
    """

    def __init__(self,
                 data,
                 tokenizer,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 flip_p=0.5,
                 placeholder_token="*",
                 center_crop=False,
                 ):
        """Configure the dataset; no image is opened here.

        Args:
            data: Indexable, sized collection of ``(image_path, caption)``.
            tokenizer: Callable tokenizer exposing ``model_max_length`` and
                returning an object with ``input_ids``.
            size: Target size passed to the resize and crop transforms.
            repeats: How many times the image list is repeated in ``len()``.
            interpolation: One of ``"linear"``, ``"bilinear"``, ``"bicubic"``
                or ``"lanczos"``; used for resizing.
            flip_p: Probability that ``__getitem__`` mirrors the image.
            placeholder_token: Value substituted into each caption via
                ``str.format``.
            center_crop: Use a center crop instead of a random crop.

        Raises:
            KeyError: If ``interpolation`` is not one of the known names.
        """
        self.data = data
        self.tokenizer = tokenizer

        self.num_images = len(self.data)
        self._length = self.num_images * repeats

        self.placeholder_token = placeholder_token
        self.flip_p = flip_p

        # Resolved to torchvision modes so the choice is actually applied by
        # Resize.  "linear" keeps its historical meaning (PIL's LINEAR was an
        # alias of BILINEAR and no longer exists in Pillow >= 10).
        self.interpolation = {
            "linear": transforms.InterpolationMode.BILINEAR,
            "bilinear": transforms.InterpolationMode.BILINEAR,
            "bicubic": transforms.InterpolationMode.BICUBIC,
            "lanczos": transforms.InterpolationMode.LANCZOS,
        }[interpolation]

        steps = [
            transforms.Resize(size, interpolation=self.interpolation),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
        self.image_transforms = transforms.Compose(steps)
        # The resize is deterministic and can be cached; the remaining steps
        # may be random and therefore run on every access.
        self._resize = steps[0]
        self._augment = transforms.Compose(steps[1:])

        # (image_path, caption) -> (resized RGB image, prompt, input_ids)
        self.cache = {}

    def __len__(self):
        """Return the number of images multiplied by ``repeats``."""
        return self._length

    def get_example(self, i, flipped):
        """Build the example for index ``i``.

        Args:
            i: Item index; wrapped around the number of images.
            flipped: When true the image tensor is mirrored horizontally.

        Returns:
            dict with ``prompt``, ``input_ids``, ``key`` and ``pixel_values``.
        """
        image_path, text = self.data[i % self.num_images]

        # Only the deterministic work (decode, resize, tokenize) is cached.
        # Keying by caption as well keeps rows that share an image but differ
        # in caption apart.
        entry_key = (image_path, text)
        entry = self.cache.get(entry_key)
        if entry is None:
            with Image.open(image_path) as image:
                # convert() returns a loaded copy, so the file can be closed.
                resized = self._resize(image.convert("RGB"))

            prompt = text.format(self.placeholder_token)
            input_ids = self.tokenizer(
                prompt,
                padding="max_length",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids[0]

            entry = (resized, prompt, input_ids)
            self.cache[entry_key] = entry

        resized, prompt, input_ids = entry

        pixel_values = self._augment(resized)
        if flipped:
            pixel_values = transforms.functional.hflip(pixel_values)

        # A fresh dict per call so that "key" always matches this access.
        return {
            "prompt": prompt,
            "input_ids": input_ids,
            "key": "-".join([image_path, "-", str(flipped)]),
            "pixel_values": pixel_values,
        }

    def __getitem__(self, i):
        """Return the example for index ``i``, mirrored with probability ``flip_p``."""
        flipped = random.random() < self.flip_p
        return self.get_example(i, flipped)