summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/dreambooth/csv.py177
-rw-r--r--data/dreambooth/prompt.py16
2 files changed, 193 insertions, 0 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
new file mode 100644
index 0000000..04df4c6
--- /dev/null
+++ b/data/dreambooth/csv.py
@@ -0,0 +1,177 @@
1import os
2import pandas as pd
3from pathlib import Path
4import PIL
5import pytorch_lightning as pl
6from PIL import Image
7from torch.utils.data import Dataset, DataLoader, random_split
8from torchvision import transforms
9
10
11class CSVDataModule(pl.LightningDataModule):
12 def __init__(self,
13 batch_size,
14 data_root,
15 tokenizer,
16 instance_prompt,
17 class_data_root=None,
18 class_prompt=None,
19 size=512,
20 repeats=100,
21 interpolation="bicubic",
22 identifier="*",
23 center_crop=False,
24 collate_fn=None):
25 super().__init__()
26
27 self.data_root = data_root
28 self.tokenizer = tokenizer
29 self.instance_prompt = instance_prompt
30 self.class_data_root = class_data_root
31 self.class_prompt = class_prompt
32 self.size = size
33 self.repeats = repeats
34 self.identifier = identifier
35 self.center_crop = center_crop
36 self.interpolation = interpolation
37 self.collate_fn = collate_fn
38 self.batch_size = batch_size
39
40 def prepare_data(self):
41 metadata = pd.read_csv(f'{self.data_root}/list.csv')
42 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
43 captions = [caption for caption in metadata['caption'].values]
44 skips = [skip for skip in metadata['skip'].values]
45 self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
46
47 def setup(self, stage=None):
48 train_set_size = int(len(self.data_full) * 0.8)
49 valid_set_size = len(self.data_full) - train_set_size
50 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
51
52 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
53 class_data_root=self.class_data_root,
54 class_prompt=self.class_prompt, size=self.size, repeats=self.repeats,
55 interpolation=self.interpolation, identifier=self.identifier,
56 center_crop=self.center_crop)
57 val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
58 class_data_root=self.class_data_root,
59 class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation,
60 identifier=self.identifier, center_crop=self.center_crop)
61 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
62 shuffle=True, collate_fn=self.collate_fn)
63 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
64
65 def train_dataloader(self):
66 return self.train_dataloader_
67
68 def val_dataloader(self):
69 return self.val_dataloader_
70
71
72class CSVDataset(Dataset):
73 def __init__(self,
74 data,
75 tokenizer,
76 instance_prompt,
77 class_data_root=None,
78 class_prompt=None,
79 size=512,
80 repeats=1,
81 interpolation="bicubic",
82 identifier="*",
83 center_crop=False,
84 ):
85
86 self.data = data
87 self.tokenizer = tokenizer
88 self.instance_prompt = instance_prompt
89
90 self.num_instance_images = len(self.data)
91 self._length = self.num_instance_images * repeats
92
93 self.identifier = identifier
94
95 if class_data_root is not None:
96 self.class_data_root = Path(class_data_root)
97 self.class_data_root.mkdir(parents=True, exist_ok=True)
98
99 self.class_images = list(Path(class_data_root).iterdir())
100 self.num_class_images = len(self.class_images)
101 self._length = max(self.num_class_images, self.num_instance_images)
102
103 self.class_prompt = class_prompt
104 else:
105 self.class_data_root = None
106
107 self.interpolation = {"linear": PIL.Image.LINEAR,
108 "bilinear": PIL.Image.BILINEAR,
109 "bicubic": PIL.Image.BICUBIC,
110 "lanczos": PIL.Image.LANCZOS,
111 }[interpolation]
112 self.image_transforms = transforms.Compose(
113 [
114 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
115 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
116 transforms.ToTensor(),
117 transforms.Normalize([0.5], [0.5]),
118 ]
119 )
120
121 self.cache = {}
122
123 def __len__(self):
124 return self._length
125
126 def get_example(self, i):
127 image_path, text = self.data[i % self.num_instance_images]
128
129 if image_path in self.cache:
130 return self.cache[image_path]
131
132 example = {}
133
134 instance_image = Image.open(image_path)
135 if not instance_image.mode == "RGB":
136 instance_image = instance_image.convert("RGB")
137
138 text = text.format(self.identifier)
139
140 example["prompts"] = text
141 example["instance_images"] = instance_image
142 example["instance_prompt_ids"] = self.tokenizer(
143 self.instance_prompt,
144 padding="do_not_pad",
145 truncation=True,
146 max_length=self.tokenizer.model_max_length,
147 ).input_ids
148
149 if self.class_data_root:
150 class_image = Image.open(self.class_images[i % self.num_class_images])
151 if not class_image.mode == "RGB":
152 class_image = class_image.convert("RGB")
153
154 example["class_images"] = class_image
155 example["class_prompt_ids"] = self.tokenizer(
156 self.class_prompt,
157 padding="do_not_pad",
158 truncation=True,
159 max_length=self.tokenizer.model_max_length,
160 ).input_ids
161
162 self.cache[image_path] = example
163 return example
164
165 def __getitem__(self, i):
166 example = {}
167 unprocessed_example = self.get_example(i)
168
169 example["prompts"] = unprocessed_example["prompts"]
170 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
171 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
172
173 if self.class_data_root:
174 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
175 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
176
177 return example
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
new file mode 100644
index 0000000..34f510d
--- /dev/null
+++ b/data/dreambooth/prompt.py
@@ -0,0 +1,16 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 def __init__(self, prompt, num_samples):
6 self.prompt = prompt
7 self.num_samples = num_samples
8
9 def __len__(self):
10 return self.num_samples
11
12 def __getitem__(self, index):
13 example = {}
14 example["prompt"] = self.prompt
15 example["index"] = index
16 return example