summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-23 21:47:12 +0100
committerVolpeon <git@volpeon.ink>2022-12-23 21:47:12 +0100
commit1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch)
tree42d3302610046dbc5d39d254f7a2d5d5f601aa18 /data
parentFix (diff)
downloadtextual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.gz
textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.bz2
textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.zip
Improved class prompt handling
Diffstat (limited to 'data')
-rw-r--r--data/csv.py29
1 files changed, 12 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py
index d400757..e25dd3f 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -38,6 +38,7 @@ class CSVDataItem(NamedTuple):
38 instance_image_path: Path 38 instance_image_path: Path
39 class_image_path: Path 39 class_image_path: Path
40 prompt: list[str] 40 prompt: list[str]
41 cprompt: str
41 nprompt: str 42 nprompt: str
42 43
43 44
@@ -47,8 +48,6 @@ class CSVDataModule():
47 batch_size: int, 48 batch_size: int,
48 data_file: str, 49 data_file: str,
49 prompt_processor: PromptProcessor, 50 prompt_processor: PromptProcessor,
50 instance_identifier: str,
51 class_identifier: Optional[str] = None,
52 class_subdir: str = "cls", 51 class_subdir: str = "cls",
53 num_class_images: int = 100, 52 num_class_images: int = 100,
54 size: int = 512, 53 size: int = 512,
@@ -77,8 +76,6 @@ class CSVDataModule():
77 self.num_class_images = num_class_images 76 self.num_class_images = num_class_images
78 77
79 self.prompt_processor = prompt_processor 78 self.prompt_processor = prompt_processor
80 self.instance_identifier = instance_identifier
81 self.class_identifier = class_identifier
82 self.size = size 79 self.size = size
83 self.repeats = repeats 80 self.repeats = repeats
84 self.dropout = dropout 81 self.dropout = dropout
@@ -96,14 +93,18 @@ class CSVDataModule():
96 def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: 93 def prepare_items(self, template, expansions, data) -> list[CSVDataItem]:
97 image = template["image"] if "image" in template else "{}" 94 image = template["image"] if "image" in template else "{}"
98 prompt = template["prompt"] if "prompt" in template else "{content}" 95 prompt = template["prompt"] if "prompt" in template else "{content}"
96 cprompt = template["cprompt"] if "cprompt" in template else "{content}"
99 nprompt = template["nprompt"] if "nprompt" in template else "{content}" 97 nprompt = template["nprompt"] if "nprompt" in template else "{content}"
100 98
101 return [ 99 return [
102 CSVDataItem( 100 CSVDataItem(
103 self.data_root.joinpath(image.format(item["image"])), 101 self.data_root.joinpath(image.format(item["image"])),
104 None, 102 None,
105 prompt_to_keywords(prompt.format( 103 prompt_to_keywords(
106 **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), 104 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
105 expansions
106 ),
107 cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")),
107 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 108 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
108 ) 109 )
109 for item in data 110 for item in data
@@ -123,6 +124,7 @@ class CSVDataModule():
123 item.instance_image_path, 124 item.instance_image_path,
124 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), 125 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
125 item.prompt, 126 item.prompt,
127 item.cprompt,
126 item.nprompt, 128 item.nprompt,
127 ) 129 )
128 for item in items 130 for item in items
@@ -160,12 +162,10 @@ class CSVDataModule():
160 162
161 def setup(self, stage=None): 163 def setup(self, stage=None):
162 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, 164 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
163 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
164 num_class_images=self.num_class_images, 165 num_class_images=self.num_class_images,
165 size=self.size, interpolation=self.interpolation, 166 size=self.size, interpolation=self.interpolation,
166 center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) 167 center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout)
167 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, 168 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
168 instance_identifier=self.instance_identifier,
169 size=self.size, interpolation=self.interpolation, 169 size=self.size, interpolation=self.interpolation,
170 center_crop=self.center_crop) 170 center_crop=self.center_crop)
171 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 171 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
@@ -187,9 +187,7 @@ class CSVDataset(Dataset):
187 self, 187 self,
188 data: List[CSVDataItem], 188 data: List[CSVDataItem],
189 prompt_processor: PromptProcessor, 189 prompt_processor: PromptProcessor,
190 instance_identifier: str,
191 batch_size: int = 1, 190 batch_size: int = 1,
192 class_identifier: Optional[str] = None,
193 num_class_images: int = 0, 191 num_class_images: int = 0,
194 size: int = 512, 192 size: int = 512,
195 repeats: int = 1, 193 repeats: int = 1,
@@ -201,8 +199,6 @@ class CSVDataset(Dataset):
201 self.data = data 199 self.data = data
202 self.prompt_processor = prompt_processor 200 self.prompt_processor = prompt_processor
203 self.batch_size = batch_size 201 self.batch_size = batch_size
204 self.instance_identifier = instance_identifier
205 self.class_identifier = class_identifier
206 self.num_class_images = num_class_images 202 self.num_class_images = num_class_images
207 self.dropout = dropout 203 self.dropout = dropout
208 self.image_cache = {} 204 self.image_cache = {}
@@ -239,14 +235,12 @@ class CSVDataset(Dataset):
239 235
240 return image 236 return image
241 237
242 def get_input_ids(self, prompt, identifier):
243 return self.prompt_processor.get_input_ids(prompt.format(identifier))
244
245 def get_example(self, i): 238 def get_example(self, i):
246 item = self.data[i % self.num_instance_images] 239 item = self.data[i % self.num_instance_images]
247 240
248 example = {} 241 example = {}
249 example["prompts"] = item.prompt 242 example["prompts"] = item.prompt
243 example["cprompts"] = item.cprompt
250 example["nprompts"] = item.nprompt 244 example["nprompts"] = item.nprompt
251 example["instance_images"] = self.get_image(item.instance_image_path) 245 example["instance_images"] = self.get_image(item.instance_image_path)
252 if self.num_class_images != 0: 246 if self.num_class_images != 0:
@@ -260,13 +254,14 @@ class CSVDataset(Dataset):
260 example = {} 254 example = {}
261 255
262 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) 256 example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout)
257 example["cprompts"] = unprocessed_example["cprompts"]
263 example["nprompts"] = unprocessed_example["nprompts"] 258 example["nprompts"] = unprocessed_example["nprompts"]
264 259
265 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 260 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
266 example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) 261 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"])
267 262
268 if self.num_class_images != 0: 263 if self.num_class_images != 0:
269 example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) 264 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
270 example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) 265 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
271 266
272 return example 267 return example