 data/csv.py           | 22
 data/keywords.py      | 21
 infer.py              | 40
 models/clip/prompt.py | 10
 train_dreambooth.py   | 14
 train_ti.py           | 14
 training/common.py    |  5
 7 files changed, 75 insertions(+), 51 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index a60733a..d1f3054 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,7 +1,6 @@
 import math
 import torch
 import json
-import numpy as np
 from pathlib import Path
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
@@ -9,32 +8,13 @@ from torchvision import transforms
 from typing import Dict, NamedTuple, List, Optional, Union, Callable
 
 from models.clip.prompt import PromptProcessor
+from data.keywords import prompt_to_keywords, keywords_to_prompt
 
 
 def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
-    if dropout != 0:
-        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
-    if shuffle:
-        np.random.shuffle(prompt)
-    return ", ".join(prompt)
-
-
-def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]:
-    def expand_keyword(keyword: str) -> list[str]:
-        return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
-
-    return [
-        kw
-        for keyword in prompt.split(", ")
-        for kw in expand_keyword(keyword)
-        if keyword != ""
-    ]
-
-
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
diff --git a/data/keywords.py b/data/keywords.py
new file mode 100644
index 0000000..9e656f3
--- /dev/null
+++ b/data/keywords.py
@@ -0,0 +1,21 @@
+import numpy as np
+
+
+def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
+    if dropout != 0:
+        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
+    if shuffle:
+        np.random.shuffle(prompt)
+    return ", ".join(prompt)
+
+
+def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]:
+    def expand_keyword(keyword: str) -> list[str]:
+        return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
+
+    return [
+        kw
+        for keyword in prompt.split(", ")
+        for kw in expand_keyword(keyword)
+        if keyword != ""
+    ]
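For reference, a minimal usage sketch of the relocated helpers; the expansions dict and prompt string below are invented examples, not values from the repository:

    from data.keywords import prompt_to_keywords, keywords_to_prompt

    # Hypothetical expansion table: each keyword maps to extra keywords appended after it.
    expansions = {"portrait": "detailed face, sharp focus"}

    keywords = prompt_to_keywords("portrait, a cat, ", expansions)
    # -> ["portrait", "detailed face", "sharp focus", "a cat"]  (the trailing empty keyword is dropped)

    # Reassemble into a prompt string; shuffle randomizes the keyword order,
    # dropout would drop each keyword with the given probability.
    print(keywords_to_prompt(keywords, shuffle=True))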
diff --git a/infer.py b/infer.py
index 507d0cf..9c27db4 100644
--- a/infer.py
+++ b/infer.py
@@ -25,6 +25,7 @@ from diffusers import (
 )
 from transformers import CLIPTextModel
 
+from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
@@ -49,6 +50,7 @@ default_cmds = {
     "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
+    "shuffle": True,
     "image": None,
     "image_noise": .7,
     "width": 768,
@@ -126,6 +128,10 @@ def create_cmd_parser():
         nargs="*",
     )
     parser.add_argument(
+        "--shuffle",
+        type=bool,
+    )
+    parser.add_argument(
         "--image",
         type=str,
     )
@@ -197,7 +203,7 @@ def load_embeddings(pipeline, embeddings_dir):
         pipeline.text_encoder.text_model.embeddings,
         Path(embeddings_dir)
     )
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
 
 def create_pipeline(model, dtype):
@@ -228,20 +234,35 @@ def create_pipeline(model, dtype):
 
 
 @torch.inference_mode()
-def generate(output_dir, pipeline, args):
+def generate(output_dir: Path, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
 
+    if args.shuffle:
+        args.prompt *= args.batch_size
+        args.batch_size = 1
+        args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt]
+
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    use_subdirs = len(args.prompt) != 1
-    if use_subdirs:
+    image_dir = []
+
+    if len(args.prompt) != 1:
         if len(args.project) != 0:
             output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}")
         else:
             output_dir = output_dir.joinpath(now)
+
+        for prompt in args.prompt:
+            dir = output_dir.joinpath(slugify(prompt)[:100])
+            dir.mkdir(parents=True, exist_ok=True)
+            image_dir.append(dir)
+
+            with open(dir.joinpath('prompt.txt'), 'w') as f:
+                f.write(prompt)
     else:
         output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
-    output_dir.mkdir(parents=True, exist_ok=True)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        image_dir.append(output_dir)
 
     args.seed = args.seed or torch.random.seed()
 
@@ -293,12 +314,9 @@ def generate(output_dir, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
-            image_dir = output_dir
-            if use_subdirs:
-                image_dir = image_dir.joinpath(slugify(args.prompt[j % len(args.prompt)])[:100])
-                image_dir.mkdir(parents=True, exist_ok=True)
-            image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.png"))
-            image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85)
+            dir = image_dir[j % len(args.prompt)]
+            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png"))
+            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85)
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
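With shuffle enabled, generate() now trades the batch dimension for per-image prompt variation: the prompt list is repeated batch_size times, the batch size drops to 1, and each copy gets its keywords in a fresh random order. A rough standalone sketch of that expansion (the prompt below is an invented example):

    from data.keywords import prompt_to_keywords, keywords_to_prompt

    prompts = ["portrait, detailed face, sharp focus"]  # invented example prompt
    batch_size = 4

    prompts = prompts * batch_size   # one prompt entry per image that would have been in the batch
    batch_size = 1                   # images are then generated one at a time
    prompts = [keywords_to_prompt(prompt_to_keywords(p), shuffle=True) for p in prompts]
    print(prompts)                   # four copies, each with its keywords in a different order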
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9da3955..a7380be 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
 
 import torch
 
@@ -16,7 +16,7 @@ class PromptProcessor():
             padding="do_not_pad",
         ).input_ids
 
-    def unify_input_ids(self, input_ids: list[int]):
+    def unify_input_ids(self, input_ids: list[list[int]]):
         return self.tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
@@ -24,13 +24,15 @@ class PromptProcessor():
             return_tensors="pt"
         )
 
-    def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
+    def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
         prompts = input_ids.shape[0]
 
         input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        if position_ids is not None:
+            position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
         if attention_mask is not None:
             attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
 
-        text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
+        text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
         text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
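A minimal calling sketch for the extended get_embeddings(), assuming PromptProcessor is constructed from the tokenizer and text encoder it uses as self.tokenizer and self.text_encoder (the constructor is not shown in this diff). The explicit position_ids merely reproduce the encoder's default 0..model_max_length-1 positions per chunk, purely to exercise the new parameter:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    from models.clip.prompt import PromptProcessor

    # Assumption: PromptProcessor(tokenizer, text_encoder) wires up self.tokenizer / self.text_encoder.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = PromptProcessor(tokenizer, text_encoder)

    # One prompt, padded to a full chunk of model_max_length tokens,
    # since get_embeddings() reshapes input_ids to (-1, model_max_length).
    enc = tokenizer("a photo of a cat", padding="max_length", return_tensors="pt")

    # Explicit per-chunk positions 0..model_max_length-1 (same as the encoder's default).
    position_ids = torch.arange(tokenizer.model_max_length).unsqueeze(0)

    embeddings = processor.get_embeddings(
        enc.input_ids,
        position_ids=position_ids,        # keyword arguments: attention_mask now follows position_ids
        attention_mask=enc.attention_mask,
    )
    print(embeddings.shape)  # (1, model_max_length, hidden_size)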
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1fd86b1..4d1e0a3 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -861,13 +861,13 @@ def main():
 
     loop = partial(
         run_model,
-        vae=vae,
-        noise_scheduler=noise_scheduler,
-        unet=unet,
-        prompt_processor=prompt_processor,
-        num_class_images=args.num_class_images,
-        prior_loss_weight=args.prior_loss_weight,
-        seed=args.seed,
+        vae,
+        noise_scheduler,
+        unet,
+        prompt_processor,
+        args.num_class_images,
+        args.prior_loss_weight,
+        args.seed,
     )
 
     # We need to initialize the trackers we use, and also store our configuration.
diff --git a/train_ti.py b/train_ti.py
index 164cf67..98385dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -814,13 +814,13 @@ def main():
 
     loop = partial(
         run_model,
-        vae=vae,
-        noise_scheduler=noise_scheduler,
-        unet=unet,
-        prompt_processor=prompt_processor,
-        num_class_images=args.num_class_images,
-        prior_loss_weight=args.prior_loss_weight,
-        seed=args.seed,
+        vae,
+        noise_scheduler,
+        unet,
+        prompt_processor,
+        args.num_class_images,
+        args.prior_loss_weight,
+        args.seed,
     )
 
     # We need to initialize the trackers we use, and also store our configuration.
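Both training scripts switch functools.partial from keyword to positional binding, so the pre-bound values must now line up with run_model's leading parameters in order, and the remaining arguments are supplied at the call site. A toy illustration of the difference (the function below is an invented stand-in, not the repository's run_model):

    from functools import partial

    def run_model(vae, noise_scheduler, unet, step):
        return vae, noise_scheduler, unet, step

    # Positional binding: bound values are prepended, in order, before call-time arguments.
    loop = partial(run_model, "vae", "scheduler", "unet")
    print(loop(42))          # ('vae', 'scheduler', 'unet', 42)

    # Keyword binding is order-independent, but the remaining parameters must then
    # be passed by keyword, since positional call-site arguments would collide
    # with the already-bound leading parameters.
    loop_kw = partial(run_model, unet="unet", vae="vae", noise_scheduler="scheduler")
    print(loop_kw(step=42))  # ('vae', 'scheduler', 'unet', 42)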
diff --git a/training/common.py b/training/common.py
index 99a6e67..ab2741a 100644
--- a/training/common.py
+++ b/training/common.py
@@ -40,7 +40,10 @@ def run_model(
     noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
     # Get the text embedding for conditioning
-    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+    encoder_hidden_states = prompt_processor.get_embeddings(
+        batch["input_ids"],
+        batch["attention_mask"]
+    )
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
 
     # Predict the noise residual