summary | refs | log | tree | commit | diff | stats
path: root/data
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2023-03-26 14:27:54 +0200
committer: Volpeon <git@volpeon.ink> 2023-03-26 14:27:54 +0200
commit: 19ae465203c8dcc0b1179584db632015362b5e44 (patch)
tree: ad6d45e78826f525c336927e4269197667f1f354 /data
parent: Fix training with guidance (diff)
download: textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.gz
textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.bz2
textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.zip
Improved inverted tokens
Diffstat (limited to 'data')
-rw-r--r-- data/csv.py 67
1 files changed, 44 insertions, 23 deletions
diff --git a/data/csv.py b/data/csv.py
index d52d251..9770bec 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -178,6 +178,7 @@ class VlpnDataModule():
178 shuffle: bool = False, 178 shuffle: bool = False,
179 interpolation: str = "bicubic", 179 interpolation: str = "bicubic",
180 template_key: str = "template", 180 template_key: str = "template",
181 placeholder_tokens: list[str] = [],
181 valid_set_size: Optional[int] = None, 182 valid_set_size: Optional[int] = None,
182 train_set_pad: Optional[int] = None, 183 train_set_pad: Optional[int] = None,
183 valid_set_pad: Optional[int] = None, 184 valid_set_pad: Optional[int] = None,
@@ -195,6 +196,7 @@ class VlpnDataModule():
195 self.data_root = self.data_file.parent 196 self.data_root = self.data_file.parent
196 self.class_root = self.data_root / class_subdir 197 self.class_root = self.data_root / class_subdir
197 self.class_root.mkdir(parents=True, exist_ok=True) 198 self.class_root.mkdir(parents=True, exist_ok=True)
199 self.placeholder_tokens = placeholder_tokens
198 self.num_class_images = num_class_images 200 self.num_class_images = num_class_images
199 self.with_guidance = with_guidance 201 self.with_guidance = with_guidance
200 202
@@ -217,31 +219,50 @@ class VlpnDataModule():
217 self.dtype = dtype 219 self.dtype = dtype
218 220
219 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: 221 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
220 image = template["image"] if "image" in template else "{}" 222 tpl_image = template["image"] if "image" in template else "{}"
221 prompt = template["prompt"] if "prompt" in template else "{content}" 223 tpl_prompt = template["prompt"] if "prompt" in template else "{content}"
222 cprompt = template["cprompt"] if "cprompt" in template else "{content}" 224 tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}"
223 nprompt = template["nprompt"] if "nprompt" in template else "{content}" 225 tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}"
226
227 items = []
228
229 for item in data:
230 image = tpl_image.format(item["image"])
231 prompt = item["prompt"] if "prompt" in item else ""
232 nprompt = item["nprompt"] if "nprompt" in item else ""
233 collection = item["collection"].split(", ") if "collection" in item else []
234
235 prompt_keywords = prompt_to_keywords(
236 tpl_prompt.format(**prepare_prompt(prompt)),
237 expansions
238 )
224 239
225 return [ 240 cprompt = keywords_to_prompt(prompt_to_keywords(
226 VlpnDataItem( 241 tpl_cprompt.format(**prepare_prompt(prompt)),
227 self.data_root / image.format(item["image"]), 242 expansions
228 None, 243 ))
229 prompt_to_keywords( 244
230 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 245 inverted_tokens = keywords_to_prompt([
231 expansions 246 f"inv_{token}"
232 ), 247 for token in self.placeholder_tokens
233 keywords_to_prompt(prompt_to_keywords( 248 if token in prompt_keywords
234 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 249 ])
235 expansions 250
236 )), 251 nprompt_keywords = prompt_to_keywords(
237 prompt_to_keywords( 252 tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)),
238 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 253 expansions
239 expansions
240 ),
241 item["collection"].split(", ") if "collection" in item else []
242 ) 254 )
243 for item in data 255
244 ] 256 items.append(VlpnDataItem(
257 self.data_root / image,
258 None,
259 prompt_keywords,
260 cprompt,
261 nprompt_keywords,
262 collection
263 ))
264
265 return items
245 266
246 def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: 267 def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]:
247 if self.filter is None: 268 if self.filter is None: