summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- data/csv.py 67
-rw-r--r-- data/keywords.py 12
-rw-r--r-- infer.py 4
-rw-r--r-- train_ti.py 2
-rw-r--r-- training/functional.py 9
5 files changed, 46 insertions, 48 deletions
diff --git a/data/csv.py b/data/csv.py
index c00ea07..d0ac317 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -11,7 +11,7 @@ from torch.utils.data import IterableDataset, DataLoader, random_split
11from torchvision import transforms 11from torchvision import transforms
12from transformers import CLIPTokenizer 12from transformers import CLIPTokenizer
13 13
14from data.keywords import prompt_to_keywords, keywords_to_prompt 14from data.keywords import str_to_keywords, keywords_to_str
15from models.clip.util import unify_input_ids 15from models.clip.util import unify_input_ids
16 16
17 17
@@ -37,7 +37,7 @@ def get_image(path):
37 return image 37 return image
38 38
39 39
40def prepare_prompt(prompt: Union[str, dict[str, str]]): 40def prepare_tpl_slots(prompt: Union[str, dict[str, str]]):
41 return {"content": prompt} if isinstance(prompt, str) else prompt 41 return {"content": prompt} if isinstance(prompt, str) else prompt
42 42
43 43
@@ -135,11 +135,18 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
135class VlpnDataItem(NamedTuple): 135class VlpnDataItem(NamedTuple):
136 instance_image_path: Path 136 instance_image_path: Path
137 class_image_path: Path 137 class_image_path: Path
138 prompt: list[str] 138 keywords: list[str]
139 prompt: str
139 cprompt: str 140 cprompt: str
140 nprompt: list[str] 141 nprompt: str
141 collection: list[str] 142 collection: list[str]
142 143
144 def full_prompt(self, dropout: float = 0, shuffle: bool = False):
145 prompt = self.prompt
146 if len(self.keywords):
147 prompt += ", " + keywords_to_str(self.keywords, dropout, shuffle)
148 return prompt
149
143 150
144def keyword_filter( 151def keyword_filter(
145 placeholder_tokens: Optional[list[str]], 152 placeholder_tokens: Optional[list[str]],
@@ -147,10 +154,11 @@ def keyword_filter(
147 exclude_collections: Optional[list[str]], 154 exclude_collections: Optional[list[str]],
148 item: VlpnDataItem 155 item: VlpnDataItem
149): 156):
157 full_prompt = item.full_prompt()
158
150 cond1 = placeholder_tokens is None or any( 159 cond1 = placeholder_tokens is None or any(
151 keyword in part 160 token in full_prompt
152 for keyword in placeholder_tokens 161 for token in placeholder_tokens
153 for part in item.prompt
154 ) 162 )
155 cond2 = collections is None or any( 163 cond2 = collections is None or any(
156 collection in item.collection 164 collection in item.collection
@@ -224,6 +232,7 @@ class VlpnDataModule():
224 232
225 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: 233 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
226 tpl_image = template["image"] if "image" in template else "{}" 234 tpl_image = template["image"] if "image" in template else "{}"
235 tpl_keywords = template["keywords"] if "keywords" in template else "{content}"
227 tpl_prompt = template["prompt"] if "prompt" in template else "{content}" 236 tpl_prompt = template["prompt"] if "prompt" in template else "{content}"
228 tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" 237 tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}"
229 tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" 238 tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}"
@@ -232,37 +241,26 @@ class VlpnDataModule():
232 241
233 for item in data: 242 for item in data:
234 image = tpl_image.format(item["image"]) 243 image = tpl_image.format(item["image"])
235 prompt = item["prompt"] if "prompt" in item else "" 244 keywords = prepare_tpl_slots(item["keywords"] if "keywords" in item else "")
236 nprompt = item["nprompt"] if "nprompt" in item else "" 245 prompt = prepare_tpl_slots(item["prompt"] if "prompt" in item else "")
246 nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "")
237 collection = item["collection"].split(", ") if "collection" in item else [] 247 collection = item["collection"].split(", ") if "collection" in item else []
238 248
239 prompt_keywords = prompt_to_keywords( 249 saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions)
240 tpl_prompt.format(**prepare_prompt(prompt)),
241 expansions
242 )
243
244 cprompt = keywords_to_prompt(prompt_to_keywords(
245 tpl_cprompt.format(**prepare_prompt(prompt)),
246 expansions
247 ))
248 250
249 inverted_tokens = keywords_to_prompt([ 251 inverted_tokens = keywords_to_str([
250 f"inv_{token}" 252 f"inv_{token}"
251 for token in self.placeholder_tokens 253 for token in self.placeholder_tokens
252 if token in prompt_keywords 254 if token in saturated_keywords
253 ]) 255 ])
254 256
255 nprompt_keywords = prompt_to_keywords(
256 tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)),
257 expansions
258 )
259
260 items.append(VlpnDataItem( 257 items.append(VlpnDataItem(
261 self.data_root / image, 258 self.data_root / image,
262 None, 259 None,
263 prompt_keywords, 260 saturated_keywords,
264 cprompt, 261 tpl_prompt.format(**prompt),
265 nprompt_keywords, 262 tpl_cprompt.format(**prompt),
263 tpl_nprompt.format(_inv=inverted_tokens, **nprompt),
266 collection 264 collection
267 )) 265 ))
268 266
@@ -281,6 +279,7 @@ class VlpnDataModule():
281 VlpnDataItem( 279 VlpnDataItem(
282 item.instance_image_path, 280 item.instance_image_path,
283 self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", 281 self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
282 item.keywords,
284 item.prompt, 283 item.prompt,
285 item.cprompt, 284 item.cprompt,
286 item.nprompt, 285 item.nprompt,
@@ -473,15 +472,11 @@ class VlpnDataset(IterableDataset):
473 472
474 example = {} 473 example = {}
475 474
476 example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) 475 example["prompt_ids"] = self.get_input_ids(item.full_prompt())
477 example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt)) 476 example["nprompt_ids"] = self.get_input_ids(item.nprompt)
478 477
479 example["instance_prompt_ids"] = self.get_input_ids( 478 example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True))
480 keywords_to_prompt(item.prompt, self.dropout, True) 479 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
481 )
482 example["negative_prompt_ids"] = self.get_input_ids(
483 keywords_to_prompt(item.nprompt, self.dropout, True)
484 )
485 example["instance_images"] = image_transforms(get_image(item.instance_image_path)) 480 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
486 481
487 if self.num_class_images != 0: 482 if self.num_class_images != 0:
diff --git a/data/keywords.py b/data/keywords.py
index 9e656f3..7385809 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -1,21 +1,21 @@
1import numpy as np 1import numpy as np
2 2
3 3
4def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: 4def keywords_to_str(keywords: list[str], dropout: float = 0, shuffle: bool = False) -> str:
5 if dropout != 0: 5 if dropout != 0:
6 prompt = [keyword for keyword in prompt if np.random.random() > dropout] 6 keywords = [keyword for keyword in keywords if np.random.random() > dropout]
7 if shuffle: 7 if shuffle:
8 np.random.shuffle(prompt) 8 np.random.shuffle(keywords)
9 return ", ".join(prompt) 9 return ", ".join(keywords)
10 10
11 11
12def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: 12def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]:
13 def expand_keyword(keyword: str) -> list[str]: 13 def expand_keyword(keyword: str) -> list[str]:
14 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] 14 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
15 15
16 return [ 16 return [
17 kw 17 kw
18 for keyword in prompt.split(", ") 18 for keyword in s.split(", ")
19 for kw in expand_keyword(keyword) 19 for kw in expand_keyword(keyword)
20 if keyword != "" 20 if keyword != ""
21 ] 21 ]
diff --git a/infer.py b/infer.py
index cf59bba..ed86ab1 100644
--- a/infer.py
+++ b/infer.py
@@ -28,7 +28,7 @@ from diffusers import (
28) 28)
29from transformers import CLIPTextModel 29from transformers import CLIPTextModel
30 30
31from data.keywords import prompt_to_keywords, keywords_to_prompt 31from data.keywords import str_to_keywords, keywords_to_str
32from models.clip.embeddings import patch_managed_embeddings 32from models.clip.embeddings import patch_managed_embeddings
33from models.clip.tokenizer import MultiCLIPTokenizer 33from models.clip.tokenizer import MultiCLIPTokenizer
34from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 34from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
@@ -296,7 +296,7 @@ def create_pipeline(model, dtype):
296 296
297 297
298def shuffle_prompts(prompts: list[str]) -> list[str]: 298def shuffle_prompts(prompts: list[str]) -> list[str]:
299 return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts] 299 return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts]
300 300
301 301
302@torch.inference_mode() 302@torch.inference_mode()
diff --git a/train_ti.py b/train_ti.py
index 5482326..651dfbe 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -186,7 +186,7 @@ def parse_args():
186 parser.add_argument( 186 parser.add_argument(
187 "--tag_dropout", 187 "--tag_dropout",
188 type=float, 188 type=float,
189 default=0, 189 default=0.1,
190 help="Tag dropout probability.", 190 help="Tag dropout probability.",
191 ) 191 )
192 parser.add_argument( 192 parser.add_argument(
diff --git a/training/functional.py b/training/functional.py
index b9fb546..96ecbc1 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -522,9 +522,12 @@ def train_loop(
522 522
523 accelerator.wait_for_everyone() 523 accelerator.wait_for_everyone()
524 524
525 lr = lr_scheduler.get_last_lr()[0] 525 if isDadaptation:
526 if torch.is_tensor(lr): 526 lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
527 lr = lr.item() 527 else:
528 lr = lr_scheduler.get_last_lr()[0]
529 if torch.is_tensor(lr):
530 lr = lr.item()
528 531
529 lrs.append(lr) 532 lrs.append(lr)
530 533