author     Volpeon <git@volpeon.ink>  2022-10-08 21:56:54 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-08 21:56:54 +0200
commit     6aadb34af4fe5ca2dfc92fae8eee87610a5848ad (patch)
tree       f490b4794366e78f7b079eb04de1c7c00e17d34a /data
parent     Fix small details (diff)
download   textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.tar.gz
           textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.tar.bz2
           textual-inversion-diff-6aadb34af4fe5ca2dfc92fae8eee87610a5848ad.zip
Update
Diffstat (limited to 'data')
-rw-r--r--  data/csv.py  162
1 file changed, 85 insertions, 77 deletions
diff --git a/data/csv.py b/data/csv.py
index dcaf7d3..8637ac1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,27 +1,38 @@
+import math
 import pandas as pd
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
+from typing import NamedTuple, List
+
+
+class CSVDataItem(NamedTuple):
+    instance_image_path: Path
+    class_image_path: Path
+    prompt: str
+    nprompt: str
 
 
 class CSVDataModule(pl.LightningDataModule):
-    def __init__(self,
-                 batch_size,
-                 data_file,
-                 tokenizer,
-                 instance_identifier,
-                 class_identifier=None,
-                 class_subdir="db_cls",
-                 num_class_images=2,
-                 size=512,
-                 repeats=100,
-                 interpolation="bicubic",
-                 center_crop=False,
-                 valid_set_size=None,
-                 generator=None,
-                 collate_fn=None):
+    def __init__(
+        self,
+        batch_size,
+        data_file,
+        tokenizer,
+        instance_identifier,
+        class_identifier=None,
+        class_subdir="db_cls",
+        num_class_images=100,
+        size=512,
+        repeats=100,
+        interpolation="bicubic",
+        center_crop=False,
+        valid_set_size=None,
+        generator=None,
+        collate_fn=None
+    ):
         super().__init__()
 
         self.data_file = Path(data_file)
@@ -46,61 +57,50 @@ class CSVDataModule(pl.LightningDataModule):
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
+    def prepare_subdata(self, data, num_class_images=1):
+        image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
+
+        return [
+            CSVDataItem(
+                self.data_root.joinpath(item.image),
+                self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
+                item.prompt,
+                item.nprompt if "nprompt" in item else ""
+            )
+            for item in data
+            if "skip" not in item or item.skip != "x"
+            for i in range(image_multiplier)
+        ]
+
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
-        instance_image_paths = [
-            self.data_root.joinpath(f)
-            for f in metadata['image'].values
-            for i in range(self.num_class_images)
-        ]
-        class_image_paths = [
-            self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}")
-            for f in metadata['image'].values
-            for i in range(self.num_class_images)
-        ]
-        prompts = [
-            prompt
-            for prompt in metadata['prompt'].values
-            for i in range(self.num_class_images)
-        ]
-        nprompts = [
-            nprompt
-            for nprompt in metadata['nprompt'].values
-            for i in range(self.num_class_images)
-        ] if 'nprompt' in metadata else [""] * len(instance_image_paths)
-        skips = [
-            skip
-            for skip in metadata['skip'].values
-            for i in range(self.num_class_images)
-        ] if 'skip' in metadata else [""] * len(instance_image_paths)
-        self.data = [
-            (i, c, p, n)
-            for i, c, p, n, s
-            in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
-            if s != "x"
-        ]
+        metadata = list(metadata.itertuples())
+        num_images = len(metadata)
 
-    def setup(self, stage=None):
-        valid_set_size = int(len(self.data) * 0.2)
+        valid_set_size = int(num_images * 0.2)
         if self.valid_set_size:
             valid_set_size = min(valid_set_size, self.valid_set_size)
         valid_set_size = max(valid_set_size, 1)
-        train_set_size = len(self.data) - valid_set_size
+        train_set_size = num_images - valid_set_size
 
-        self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator)
+        data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
 
-        train_dataset = CSVDataset(self.data_train, self.tokenizer,
+        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
+        self.data_val = self.prepare_subdata(data_val)
+
+    def setup(self, stage=None):
+        train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size,
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                    num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
                                    center_crop=self.center_crop, repeats=self.repeats)
-        val_dataset = CSVDataset(self.data_val, self.tokenizer,
+        val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size,
                                  instance_identifier=self.instance_identifier,
                                  size=self.size, interpolation=self.interpolation,
                                  center_crop=self.center_crop, repeats=self.repeats)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True,
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
                                             shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True,
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
                                           pin_memory=True, collate_fn=self.collate_fn)
 
     def train_dataloader(self):
@@ -111,24 +111,28 @@ class CSVDataModule(pl.LightningDataModule):
 
 
 class CSVDataset(Dataset):
-    def __init__(self,
-                 data,
-                 tokenizer,
-                 instance_identifier,
-                 class_identifier=None,
-                 num_class_images=2,
-                 size=512,
-                 repeats=1,
-                 interpolation="bicubic",
-                 center_crop=False,
-                 ):
+    def __init__(
+        self,
+        data: List[CSVDataItem],
+        tokenizer,
+        instance_identifier,
+        batch_size=1,
+        class_identifier=None,
+        num_class_images=0,
+        size=512,
+        repeats=1,
+        interpolation="bicubic",
+        center_crop=False,
+    ):
 
         self.data = data
         self.tokenizer = tokenizer
+        self.batch_size = batch_size
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
         self.cache = {}
+        self.image_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -149,46 +153,50 @@ class CSVDataset(Dataset):
         )
 
     def __len__(self):
-        return self._length
+        return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
-        cache_key = f"{instance_image_path}_{class_image_path}"
+        item = self.data[i % self.num_instance_images]
+        cache_key = f"{item.instance_image_path}_{item.class_image_path}"
 
         if cache_key in self.cache:
             return self.cache[cache_key]
 
         example = {}
 
-        example["prompts"] = prompt
-        example["nprompts"] = nprompt
+        example["prompts"] = item.prompt
+        example["nprompts"] = item.nprompt
 
-        instance_image = Image.open(instance_image_path)
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
+        if item.instance_image_path in self.image_cache:
+            instance_image = self.image_cache[item.instance_image_path]
+        else:
+            instance_image = Image.open(item.instance_image_path)
+            if not instance_image.mode == "RGB":
+                instance_image = instance_image.convert("RGB")
+            self.image_cache[item.instance_image_path] = instance_image
 
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
-            prompt.format(self.instance_identifier),
+            item.prompt.format(self.instance_identifier),
             padding="do_not_pad",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
         ).input_ids
 
         if self.num_class_images != 0:
-            class_image = Image.open(class_image_path)
+            class_image = Image.open(item.class_image_path)
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
 
             example["class_images"] = class_image
             example["class_prompt_ids"] = self.tokenizer(
-                prompt.format(self.class_identifier),
+                item.prompt.format(self.class_identifier),
                 padding="do_not_pad",
                 truncation=True,
                 max_length=self.tokenizer.model_max_length,
            ).input_ids
 
-        self.cache[instance_image_path] = example
+        self.cache[item.instance_image_path] = example
         return example
 
     def __getitem__(self, i):
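
For context, a minimal sketch of how the reworked CSVDataModule might be driven after this commit. The tokenizer, CSV path, and identifier strings below are hypothetical placeholders, not part of the repository:

    # Hypothetical usage sketch (not part of the commit); all paths and
    # identifiers here are assumptions for illustration.
    from transformers import CLIPTokenizer

    from data.csv import CSVDataModule

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    datamodule = CSVDataModule(
        batch_size=4,
        data_file="train/data.csv",         # assumed CSV with image/prompt(/nprompt/skip) columns
        tokenizer=tokenizer,
        instance_identifier="<embedding>",  # substituted into each prompt via str.format
        class_identifier="a photo",         # enables class (prior-preservation) images
        num_class_images=100,               # new default introduced by this commit
    )

    datamodule.prepare_data()  # read the CSV, split train/val, build CSVDataItem lists
    datamodule.setup()         # construct the CSVDataset instances and DataLoaders

    for batch in datamodule.train_dataloader():
        pass  # feed `batch` to the training step

Note the division of labor after this commit: prepare_data() now performs the train/val split and expands each row into CSVDataItem entries (replicating items so at least num_class_images class images exist), while setup() only wraps the prepared lists in datasets and loaders. CSVDataset.__len__ rounds the length up to a multiple of batch_size and __getitem__ wraps indices modulo the item count, which is why the drop_last=True flags could be removed from both DataLoaders.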