summaryrefslogtreecommitdiffstats
path: root/data/dreambooth
diff options
context:
space:
mode:
Diffstat (limited to 'data/dreambooth')
-rw-r--r--data/dreambooth/csv.py181
1 files changed, 0 insertions, 181 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
deleted file mode 100644
index abd329d..0000000
--- a/data/dreambooth/csv.py
+++ /dev/null
@@ -1,181 +0,0 @@
1import math
2import os
3import pandas as pd
4from pathlib import Path
5import pytorch_lightning as pl
6from PIL import Image
7from torch.utils.data import Dataset, DataLoader, random_split
8from torchvision import transforms
9
10
11class CSVDataModule(pl.LightningDataModule):
12 def __init__(self,
13 batch_size,
14 data_file,
15 tokenizer,
16 instance_identifier,
17 class_identifier=None,
18 class_subdir="db_cls",
19 size=512,
20 repeats=100,
21 interpolation="bicubic",
22 center_crop=False,
23 valid_set_size=None,
24 generator=None,
25 collate_fn=None):
26 super().__init__()
27
28 self.data_file = Path(data_file)
29
30 if not self.data_file.is_file():
31 raise ValueError("data_file must be a file")
32
33 self.data_root = self.data_file.parent
34 self.class_root = self.data_root.joinpath(class_subdir)
35 self.class_root.mkdir(parents=True, exist_ok=True)
36
37 self.tokenizer = tokenizer
38 self.instance_identifier = instance_identifier
39 self.class_identifier = class_identifier
40 self.size = size
41 self.repeats = repeats
42 self.center_crop = center_crop
43 self.interpolation = interpolation
44 self.valid_set_size = valid_set_size
45 self.generator = generator
46 self.collate_fn = collate_fn
47 self.batch_size = batch_size
48
49 def prepare_data(self):
50 metadata = pd.read_csv(self.data_file)
51 instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values]
52 class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values]
53 prompts = metadata['prompt'].values
54 nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths)
55 skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths)
56 self.data = [(i, c, p, n)
57 for i, c, p, n, s
58 in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
59 if s != "x"]
60
61 def setup(self, stage=None):
62 valid_set_size = int(len(self.data) * 0.2)
63 if self.valid_set_size:
64 valid_set_size = min(valid_set_size, self.valid_set_size)
65 valid_set_size = max(valid_set_size, 1)
66 train_set_size = len(self.data) - valid_set_size
67
68 self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator)
69
70 train_dataset = CSVDataset(self.data_train, self.tokenizer,
71 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
72 size=self.size, interpolation=self.interpolation,
73 center_crop=self.center_crop, repeats=self.repeats)
74 val_dataset = CSVDataset(self.data_val, self.tokenizer,
75 instance_identifier=self.instance_identifier,
76 size=self.size, interpolation=self.interpolation,
77 center_crop=self.center_crop, repeats=self.repeats)
78 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True,
79 shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
80 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True,
81 pin_memory=True, collate_fn=self.collate_fn)
82
83 def train_dataloader(self):
84 return self.train_dataloader_
85
86 def val_dataloader(self):
87 return self.val_dataloader_
88
89
90class CSVDataset(Dataset):
91 def __init__(self,
92 data,
93 tokenizer,
94 instance_identifier,
95 class_identifier=None,
96 size=512,
97 repeats=1,
98 interpolation="bicubic",
99 center_crop=False,
100 ):
101
102 self.data = data
103 self.tokenizer = tokenizer
104 self.instance_identifier = instance_identifier
105 self.class_identifier = class_identifier
106 self.cache = {}
107
108 self.num_instance_images = len(self.data)
109 self._length = self.num_instance_images * repeats
110
111 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
112 "bilinear": transforms.InterpolationMode.BILINEAR,
113 "bicubic": transforms.InterpolationMode.BICUBIC,
114 "lanczos": transforms.InterpolationMode.LANCZOS,
115 }[interpolation]
116 self.image_transforms = transforms.Compose(
117 [
118 transforms.Resize(size, interpolation=self.interpolation),
119 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
120 transforms.RandomHorizontalFlip(),
121 transforms.ToTensor(),
122 transforms.Normalize([0.5], [0.5]),
123 ]
124 )
125
126 def __len__(self):
127 return self._length
128
129 def get_example(self, i):
130 instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
131
132 if instance_image_path in self.cache:
133 return self.cache[instance_image_path]
134
135 example = {}
136
137 example["prompts"] = prompt
138 example["nprompts"] = nprompt
139
140 instance_image = Image.open(instance_image_path)
141 if not instance_image.mode == "RGB":
142 instance_image = instance_image.convert("RGB")
143
144 example["instance_images"] = instance_image
145 example["instance_prompt_ids"] = self.tokenizer(
146 prompt.format(self.instance_identifier),
147 padding="do_not_pad",
148 truncation=True,
149 max_length=self.tokenizer.model_max_length,
150 ).input_ids
151
152 if self.class_identifier is not None:
153 class_image = Image.open(class_image_path)
154 if not class_image.mode == "RGB":
155 class_image = class_image.convert("RGB")
156
157 example["class_images"] = class_image
158 example["class_prompt_ids"] = self.tokenizer(
159 prompt.format(self.class_identifier),
160 padding="do_not_pad",
161 truncation=True,
162 max_length=self.tokenizer.model_max_length,
163 ).input_ids
164
165 self.cache[instance_image_path] = example
166 return example
167
168 def __getitem__(self, i):
169 example = {}
170 unprocessed_example = self.get_example(i)
171
172 example["prompts"] = unprocessed_example["prompts"]
173 example["nprompts"] = unprocessed_example["nprompts"]
174 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
175 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
176
177 if self.class_identifier is not None:
178 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
179 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
180
181 return example