Diffstat (limited to 'data')
-rw-r--r--   data/csv.py   162
1 file changed, 85 insertions(+), 77 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index dcaf7d3..8637ac1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,27 +1,38 @@
+import math
 import pandas as pd
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
+from typing import NamedTuple, List
+
+
+class CSVDataItem(NamedTuple):
+    instance_image_path: Path
+    class_image_path: Path
+    prompt: str
+    nprompt: str
 
 
 class CSVDataModule(pl.LightningDataModule):
-    def __init__(self,
-                 batch_size,
-                 data_file,
-                 tokenizer,
-                 instance_identifier,
-                 class_identifier=None,
-                 class_subdir="db_cls",
-                 num_class_images=2,
-                 size=512,
-                 repeats=100,
-                 interpolation="bicubic",
-                 center_crop=False,
-                 valid_set_size=None,
-                 generator=None,
-                 collate_fn=None):
+    def __init__(
+        self,
+        batch_size,
+        data_file,
+        tokenizer,
+        instance_identifier,
+        class_identifier=None,
+        class_subdir="db_cls",
+        num_class_images=100,
+        size=512,
+        repeats=100,
+        interpolation="bicubic",
+        center_crop=False,
+        valid_set_size=None,
+        generator=None,
+        collate_fn=None
+    ):
         super().__init__()
 
         self.data_file = Path(data_file)
@@ -46,61 +57,50 @@ class CSVDataModule(pl.LightningDataModule):
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
+    def prepare_subdata(self, data, num_class_images=1):
+        image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
+
+        return [
+            CSVDataItem(
+                self.data_root.joinpath(item.image),
+                self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
+                item.prompt,
+                item.nprompt if "nprompt" in item else ""
+            )
+            for item in data
+            if "skip" not in item or item.skip != "x"
+            for i in range(image_multiplier)
+        ]
+
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
-        instance_image_paths = [
-            self.data_root.joinpath(f)
-            for f in metadata['image'].values
-            for i in range(self.num_class_images)
-        ]
-        class_image_paths = [
-            self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}")
-            for f in metadata['image'].values
-            for i in range(self.num_class_images)
-        ]
-        prompts = [
-            prompt
-            for prompt in metadata['prompt'].values
-            for i in range(self.num_class_images)
-        ]
-        nprompts = [
-            nprompt
-            for nprompt in metadata['nprompt'].values
-            for i in range(self.num_class_images)
-        ] if 'nprompt' in metadata else [""] * len(instance_image_paths)
-        skips = [
-            skip
-            for skip in metadata['skip'].values
-            for i in range(self.num_class_images)
-        ] if 'skip' in metadata else [""] * len(instance_image_paths)
-        self.data = [
-            (i, c, p, n)
-            for i, c, p, n, s
-            in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
-            if s != "x"
-        ]
+        metadata = list(metadata.itertuples())
+        num_images = len(metadata)
 
-    def setup(self, stage=None):
-        valid_set_size = int(len(self.data) * 0.2)
+        valid_set_size = int(num_images * 0.2)
         if self.valid_set_size:
             valid_set_size = min(valid_set_size, self.valid_set_size)
         valid_set_size = max(valid_set_size, 1)
-        train_set_size = len(self.data) - valid_set_size
+        train_set_size = num_images - valid_set_size
 
-        self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator)
+        data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
 
-        train_dataset = CSVDataset(self.data_train, self.tokenizer,
+        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
+        self.data_val = self.prepare_subdata(data_val)
+
+    def setup(self, stage=None):
+        train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size,
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                    num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
                                    center_crop=self.center_crop, repeats=self.repeats)
-        val_dataset = CSVDataset(self.data_val, self.tokenizer,
+        val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size,
                                  instance_identifier=self.instance_identifier,
                                  size=self.size, interpolation=self.interpolation,
                                  center_crop=self.center_crop, repeats=self.repeats)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True,
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
                                             shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True,
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
                                           pin_memory=True, collate_fn=self.collate_fn)
 
     def train_dataloader(self):
@@ -111,24 +111,28 @@ class CSVDataModule(pl.LightningDataModule):
 
 
 class CSVDataset(Dataset):
-    def __init__(self,
-                 data,
-                 tokenizer,
-                 instance_identifier,
-                 class_identifier=None,
-                 num_class_images=2,
-                 size=512,
-                 repeats=1,
-                 interpolation="bicubic",
-                 center_crop=False,
-                 ):
+    def __init__(
+        self,
+        data: List[CSVDataItem],
+        tokenizer,
+        instance_identifier,
+        batch_size=1,
+        class_identifier=None,
+        num_class_images=0,
+        size=512,
+        repeats=1,
+        interpolation="bicubic",
+        center_crop=False,
+    ):
 
         self.data = data
         self.tokenizer = tokenizer
+        self.batch_size = batch_size
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
         self.cache = {}
+        self.image_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -149,46 +153,50 @@ class CSVDataset(Dataset):
         )
 
     def __len__(self):
-        return self._length
+        return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
-        cache_key = f"{instance_image_path}_{class_image_path}"
+        item = self.data[i % self.num_instance_images]
+        cache_key = f"{item.instance_image_path}_{item.class_image_path}"
 
         if cache_key in self.cache:
             return self.cache[cache_key]
 
         example = {}
 
-        example["prompts"] = prompt
-        example["nprompts"] = nprompt
+        example["prompts"] = item.prompt
+        example["nprompts"] = item.nprompt
 
-        instance_image = Image.open(instance_image_path)
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
+        if item.instance_image_path in self.image_cache:
+            instance_image = self.image_cache[item.instance_image_path]
+        else:
+            instance_image = Image.open(item.instance_image_path)
+            if not instance_image.mode == "RGB":
+                instance_image = instance_image.convert("RGB")
+            self.image_cache[item.instance_image_path] = instance_image
 
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
-            prompt.format(self.instance_identifier),
+            item.prompt.format(self.instance_identifier),
             padding="do_not_pad",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
         ).input_ids
 
         if self.num_class_images != 0:
-            class_image = Image.open(class_image_path)
+            class_image = Image.open(item.class_image_path)
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
 
             example["class_images"] = class_image
             example["class_prompt_ids"] = self.tokenizer(
-                prompt.format(self.class_identifier),
+                item.prompt.format(self.class_identifier),
                 padding="do_not_pad",
                 truncation=True,
                 max_length=self.tokenizer.model_max_length,
             ).input_ids
 
-        self.cache[instance_image_path] = example
+        self.cache[item.instance_image_path] = example
         return example
 
     def __getitem__(self, i):
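Usage note: the snippet below is a minimal, hypothetical sketch of how the reworked CSVDataModule could be driven after this change. The CSV column names (image, prompt, and the optional nprompt and skip columns) are taken from prepare_data()/prepare_subdata() above; the tokenizer choice, file paths, identifiers, and collate function are placeholder assumptions, not part of this commit.

# Hypothetical usage sketch; paths, identifiers and the tokenizer are stand-ins.
from transformers import CLIPTokenizer

# Assumed CSV layout, inferred from prepare_subdata() above:
#   image,prompt,nprompt,skip
#   img/0001.jpg,a photo of {},lowres,
#   img/0002.jpg,a photo of {},lowres,x    <- a skip value of "x" drops the row
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

datamodule = CSVDataModule(
    batch_size=4,
    data_file="data/train.csv",
    tokenizer=tokenizer,
    instance_identifier="<token>",   # substituted via prompt.format(...) for instance prompts
    class_identifier="person",       # used for the class (prior-preservation) prompts
    num_class_images=100,            # class image paths are expanded per instance image
    valid_set_size=32,
    # collate_fn=...                 # a project-specific collate_fn is normally supplied here
)

datamodule.prepare_data()   # read the CSV, split train/val, build CSVDataItem lists
datamodule.setup()          # construct CSVDataset instances and their DataLoaders

train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()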