diff options
author | Volpeon <git@volpeon.ink> | 2022-10-18 15:23:40 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-18 15:23:40 +0200 |
commit | 306f2bfb620e6882737658bd3694c79365d75e4b (patch) | |
tree | 8b461c4360b9baa5758c2af0100348f14df8c76d /data | |
parent | Implemented extended prompt limit (diff) | |
download | textual-inversion-diff-306f2bfb620e6882737658bd3694c79365d75e4b.tar.gz textual-inversion-diff-306f2bfb620e6882737658bd3694c79365d75e4b.tar.bz2 textual-inversion-diff-306f2bfb620e6882737658bd3694c79365d75e4b.zip |
Improved prompt handling
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 83 |
1 files changed, 37 insertions, 46 deletions
diff --git a/data/csv.py b/data/csv.py index 316c099..4c91ded 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -1,11 +1,14 @@ | |||
1 | import math | 1 | import math |
2 | import pandas as pd | 2 | import pandas as pd |
3 | import torch | ||
3 | from pathlib import Path | 4 | from pathlib import Path |
4 | import pytorch_lightning as pl | 5 | import pytorch_lightning as pl |
5 | from PIL import Image | 6 | from PIL import Image |
6 | from torch.utils.data import Dataset, DataLoader, random_split | 7 | from torch.utils.data import Dataset, DataLoader, random_split |
7 | from torchvision import transforms | 8 | from torchvision import transforms |
8 | from typing import NamedTuple, List | 9 | from typing import NamedTuple, List, Optional |
10 | |||
11 | from models.clip.prompt import PromptProcessor | ||
9 | 12 | ||
10 | 13 | ||
11 | class CSVDataItem(NamedTuple): | 14 | class CSVDataItem(NamedTuple): |
@@ -18,19 +21,19 @@ class CSVDataItem(NamedTuple): | |||
18 | class CSVDataModule(pl.LightningDataModule): | 21 | class CSVDataModule(pl.LightningDataModule): |
19 | def __init__( | 22 | def __init__( |
20 | self, | 23 | self, |
21 | batch_size, | 24 | batch_size: int, |
22 | data_file, | 25 | data_file: str, |
23 | tokenizer, | 26 | prompt_processor: PromptProcessor, |
24 | instance_identifier, | 27 | instance_identifier: str, |
25 | class_identifier=None, | 28 | class_identifier: Optional[str] = None, |
26 | class_subdir="cls", | 29 | class_subdir: str = "cls", |
27 | num_class_images=100, | 30 | num_class_images: int = 100, |
28 | size=512, | 31 | size: int = 512, |
29 | repeats=100, | 32 | repeats: int = 1, |
30 | interpolation="bicubic", | 33 | interpolation: str = "bicubic", |
31 | center_crop=False, | 34 | center_crop: bool = False, |
32 | valid_set_size=None, | 35 | valid_set_size: Optional[int] = None, |
33 | generator=None, | 36 | generator: Optional[torch.Generator] = None, |
34 | collate_fn=None | 37 | collate_fn=None |
35 | ): | 38 | ): |
36 | super().__init__() | 39 | super().__init__() |
@@ -45,7 +48,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
45 | self.class_root.mkdir(parents=True, exist_ok=True) | 48 | self.class_root.mkdir(parents=True, exist_ok=True) |
46 | self.num_class_images = num_class_images | 49 | self.num_class_images = num_class_images |
47 | 50 | ||
48 | self.tokenizer = tokenizer | 51 | self.prompt_processor = prompt_processor |
49 | self.instance_identifier = instance_identifier | 52 | self.instance_identifier = instance_identifier |
50 | self.class_identifier = class_identifier | 53 | self.class_identifier = class_identifier |
51 | self.size = size | 54 | self.size = size |
@@ -65,7 +68,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
65 | self.data_root.joinpath(item.image), | 68 | self.data_root.joinpath(item.image), |
66 | self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), | 69 | self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), |
67 | item.prompt, | 70 | item.prompt, |
68 | item.nprompt if "nprompt" in item else "" | 71 | item.nprompt |
69 | ) | 72 | ) |
70 | for item in data | 73 | for item in data |
71 | for i in range(image_multiplier) | 74 | for i in range(image_multiplier) |
@@ -88,12 +91,12 @@ class CSVDataModule(pl.LightningDataModule): | |||
88 | self.data_val = self.prepare_subdata(data_val) | 91 | self.data_val = self.prepare_subdata(data_val) |
89 | 92 | ||
90 | def setup(self, stage=None): | 93 | def setup(self, stage=None): |
91 | train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, | 94 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, |
92 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | 95 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
93 | num_class_images=self.num_class_images, | 96 | num_class_images=self.num_class_images, |
94 | size=self.size, interpolation=self.interpolation, | 97 | size=self.size, interpolation=self.interpolation, |
95 | center_crop=self.center_crop, repeats=self.repeats) | 98 | center_crop=self.center_crop, repeats=self.repeats) |
96 | val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, | 99 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
97 | instance_identifier=self.instance_identifier, | 100 | instance_identifier=self.instance_identifier, |
98 | size=self.size, interpolation=self.interpolation, | 101 | size=self.size, interpolation=self.interpolation, |
99 | center_crop=self.center_crop, repeats=self.repeats) | 102 | center_crop=self.center_crop, repeats=self.repeats) |
@@ -113,19 +116,19 @@ class CSVDataset(Dataset): | |||
113 | def __init__( | 116 | def __init__( |
114 | self, | 117 | self, |
115 | data: List[CSVDataItem], | 118 | data: List[CSVDataItem], |
116 | tokenizer, | 119 | prompt_processor: PromptProcessor, |
117 | instance_identifier, | 120 | instance_identifier: str, |
118 | batch_size=1, | 121 | batch_size: int = 1, |
119 | class_identifier=None, | 122 | class_identifier: Optional[str] = None, |
120 | num_class_images=0, | 123 | num_class_images: int = 0, |
121 | size=512, | 124 | size: int = 512, |
122 | repeats=1, | 125 | repeats: int = 1, |
123 | interpolation="bicubic", | 126 | interpolation: str = "bicubic", |
124 | center_crop=False, | 127 | center_crop: bool = False, |
125 | ): | 128 | ): |
126 | 129 | ||
127 | self.data = data | 130 | self.data = data |
128 | self.tokenizer = tokenizer | 131 | self.prompt_processor = prompt_processor |
129 | self.batch_size = batch_size | 132 | self.batch_size = batch_size |
130 | self.instance_identifier = instance_identifier | 133 | self.instance_identifier = instance_identifier |
131 | self.class_identifier = class_identifier | 134 | self.class_identifier = class_identifier |
@@ -163,12 +166,6 @@ class CSVDataset(Dataset): | |||
163 | 166 | ||
164 | example = {} | 167 | example = {} |
165 | 168 | ||
166 | if isinstance(item.prompt, str): | ||
167 | item.prompt = [item.prompt] | ||
168 | |||
169 | if isinstance(item.nprompt, str): | ||
170 | item.nprompt = [item.nprompt] | ||
171 | |||
172 | example["prompts"] = item.prompt | 169 | example["prompts"] = item.prompt |
173 | example["nprompts"] = item.nprompt | 170 | example["nprompts"] = item.nprompt |
174 | 171 | ||
@@ -181,12 +178,9 @@ class CSVDataset(Dataset): | |||
181 | self.image_cache[item.instance_image_path] = instance_image | 178 | self.image_cache[item.instance_image_path] = instance_image |
182 | 179 | ||
183 | example["instance_images"] = instance_image | 180 | example["instance_images"] = instance_image |
184 | example["instance_prompt_ids"] = self.tokenizer( | 181 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( |
185 | item.prompt.format(self.instance_identifier), | 182 | item.prompt.format(self.instance_identifier) |
186 | padding="max_length", | 183 | ) |
187 | truncation=True, | ||
188 | max_length=self.tokenizer.model_max_length, | ||
189 | ).input_ids | ||
190 | 184 | ||
191 | if self.num_class_images != 0: | 185 | if self.num_class_images != 0: |
192 | class_image = Image.open(item.class_image_path) | 186 | class_image = Image.open(item.class_image_path) |
@@ -194,12 +188,9 @@ class CSVDataset(Dataset): | |||
194 | class_image = class_image.convert("RGB") | 188 | class_image = class_image.convert("RGB") |
195 | 189 | ||
196 | example["class_images"] = class_image | 190 | example["class_images"] = class_image |
197 | example["class_prompt_ids"] = self.tokenizer( | 191 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids( |
198 | item.prompt.format(self.class_identifier), | 192 | item.nprompt.format(self.class_identifier) |
199 | padding="max_length", | 193 | ) |
200 | truncation=True, | ||
201 | max_length=self.tokenizer.model_max_length, | ||
202 | ).input_ids | ||
203 | 194 | ||
204 | self.cache[item.instance_image_path] = example | 195 | self.cache[item.instance_image_path] = example |
205 | return example | 196 | return example |