diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-05 13:26:32 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-05 13:26:32 +0100 |
| commit | 3396ca881ed3f3521617cd9024eea56975191d32 (patch) | |
| tree | 3189c3bbe77b211152d11b524d0fe3a7016441ee | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.tar.gz textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.tar.bz2 textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.zip | |
Update
| -rw-r--r-- | data/csv.py | 22 | ||||
| -rw-r--r-- | data/keywords.py | 21 | ||||
| -rw-r--r-- | infer.py | 40 | ||||
| -rw-r--r-- | models/clip/prompt.py | 10 | ||||
| -rw-r--r-- | train_dreambooth.py | 14 | ||||
| -rw-r--r-- | train_ti.py | 14 | ||||
| -rw-r--r-- | training/common.py | 5 |
7 files changed, 75 insertions, 51 deletions
diff --git a/data/csv.py b/data/csv.py index a60733a..d1f3054 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -1,7 +1,6 @@ | |||
| 1 | import math | 1 | import math |
| 2 | import torch | 2 | import torch |
| 3 | import json | 3 | import json |
| 4 | import numpy as np | ||
| 5 | from pathlib import Path | 4 | from pathlib import Path |
| 6 | from PIL import Image | 5 | from PIL import Image |
| 7 | from torch.utils.data import Dataset, DataLoader, random_split | 6 | from torch.utils.data import Dataset, DataLoader, random_split |
| @@ -9,32 +8,13 @@ from torchvision import transforms | |||
| 9 | from typing import Dict, NamedTuple, List, Optional, Union, Callable | 8 | from typing import Dict, NamedTuple, List, Optional, Union, Callable |
| 10 | 9 | ||
| 11 | from models.clip.prompt import PromptProcessor | 10 | from models.clip.prompt import PromptProcessor |
| 11 | from data.keywords import prompt_to_keywords, keywords_to_prompt | ||
| 12 | 12 | ||
| 13 | 13 | ||
| 14 | def prepare_prompt(prompt: Union[str, Dict[str, str]]): | 14 | def prepare_prompt(prompt: Union[str, Dict[str, str]]): |
| 15 | return {"content": prompt} if isinstance(prompt, str) else prompt | 15 | return {"content": prompt} if isinstance(prompt, str) else prompt |
| 16 | 16 | ||
| 17 | 17 | ||
| 18 | def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: | ||
| 19 | if dropout != 0: | ||
| 20 | prompt = [keyword for keyword in prompt if np.random.random() > dropout] | ||
| 21 | if shuffle: | ||
| 22 | np.random.shuffle(prompt) | ||
| 23 | return ", ".join(prompt) | ||
| 24 | |||
| 25 | |||
| 26 | def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]: | ||
| 27 | def expand_keyword(keyword: str) -> list[str]: | ||
| 28 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | ||
| 29 | |||
| 30 | return [ | ||
| 31 | kw | ||
| 32 | for keyword in prompt.split(", ") | ||
| 33 | for kw in expand_keyword(keyword) | ||
| 34 | if keyword != "" | ||
| 35 | ] | ||
| 36 | |||
| 37 | |||
| 38 | class CSVDataItem(NamedTuple): | 18 | class CSVDataItem(NamedTuple): |
| 39 | instance_image_path: Path | 19 | instance_image_path: Path |
| 40 | class_image_path: Path | 20 | class_image_path: Path |
diff --git a/data/keywords.py b/data/keywords.py new file mode 100644 index 0000000..9e656f3 --- /dev/null +++ b/data/keywords.py | |||
| @@ -0,0 +1,21 @@ | |||
| 1 | import numpy as np | ||
| 2 | |||
| 3 | |||
| 4 | def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str: | ||
| 5 | if dropout != 0: | ||
| 6 | prompt = [keyword for keyword in prompt if np.random.random() > dropout] | ||
| 7 | if shuffle: | ||
| 8 | np.random.shuffle(prompt) | ||
| 9 | return ", ".join(prompt) | ||
| 10 | |||
| 11 | |||
| 12 | def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]: | ||
| 13 | def expand_keyword(keyword: str) -> list[str]: | ||
| 14 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | ||
| 15 | |||
| 16 | return [ | ||
| 17 | kw | ||
| 18 | for keyword in prompt.split(", ") | ||
| 19 | for kw in expand_keyword(keyword) | ||
| 20 | if keyword != "" | ||
| 21 | ] | ||
diff --git a/infer.py b/infer.py --- a/infer.py +++ b/infer.py | |||
| @@ -25,6 +25,7 @@ from diffusers import ( | |||
| 25 | ) | 25 | ) |
| 26 | from transformers import CLIPTextModel | 26 | from transformers import CLIPTextModel |
| 27 | 27 | ||
| 28 | from data.keywords import prompt_to_keywords, keywords_to_prompt | ||
| 28 | from models.clip.embeddings import patch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
| 29 | from models.clip.tokenizer import MultiCLIPTokenizer | 30 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 30 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 31 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| @@ -49,6 +50,7 @@ default_cmds = { | |||
| 49 | "scheduler": "dpmsm", | 50 | "scheduler": "dpmsm", |
| 50 | "prompt": None, | 51 | "prompt": None, |
| 51 | "negative_prompt": None, | 52 | "negative_prompt": None, |
| 53 | "shuffle": True, | ||
| 52 | "image": None, | 54 | "image": None, |
| 53 | "image_noise": .7, | 55 | "image_noise": .7, |
| 54 | "width": 768, | 56 | "width": 768, |
| @@ -126,6 +128,10 @@ def create_cmd_parser(): | |||
| 126 | nargs="*", | 128 | nargs="*", |
| 127 | ) | 129 | ) |
| 128 | parser.add_argument( | 130 | parser.add_argument( |
| 131 | "--shuffle", | ||
| 132 | type=bool, | ||
| 133 | ) | ||
| 134 | parser.add_argument( | ||
| 129 | "--image", | 135 | "--image", |
| 130 | type=str, | 136 | type=str, |
| 131 | ) | 137 | ) |
| @@ -197,7 +203,7 @@ def load_embeddings(pipeline, embeddings_dir): | |||
| 197 | pipeline.text_encoder.text_model.embeddings, | 203 | pipeline.text_encoder.text_model.embeddings, |
| 198 | Path(embeddings_dir) | 204 | Path(embeddings_dir) |
| 199 | ) | 205 | ) |
| 200 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}") | 206 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 201 | 207 | ||
| 202 | 208 | ||
| 203 | def create_pipeline(model, dtype): | 209 | def create_pipeline(model, dtype): |
| @@ -228,20 +234,35 @@ def create_pipeline(model, dtype): | |||
| 228 | 234 | ||
| 229 | 235 | ||
| 230 | @torch.inference_mode() | 236 | @torch.inference_mode() |
| 231 | def generate(output_dir, pipeline, args): | 237 | def generate(output_dir: Path, pipeline, args): |
| 232 | if isinstance(args.prompt, str): | 238 | if isinstance(args.prompt, str): |
| 233 | args.prompt = [args.prompt] | 239 | args.prompt = [args.prompt] |
| 234 | 240 | ||
| 241 | if args.shuffle: | ||
| 242 | args.prompt *= args.batch_size | ||
| 243 | args.batch_size = 1 | ||
| 244 | args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] | ||
| 245 | |||
| 235 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 246 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 236 | use_subdirs = len(args.prompt) != 1 | 247 | image_dir = [] |
| 237 | if use_subdirs: | 248 | |
| 249 | if len(args.prompt) != 1: | ||
| 238 | if len(args.project) != 0: | 250 | if len(args.project) != 0: |
| 239 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}") | 251 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}") |
| 240 | else: | 252 | else: |
| 241 | output_dir = output_dir.joinpath(now) | 253 | output_dir = output_dir.joinpath(now) |
| 254 | |||
| 255 | for prompt in args.prompt: | ||
| 256 | dir = output_dir.joinpath(slugify(prompt)[:100]) | ||
| 257 | dir.mkdir(parents=True, exist_ok=True) | ||
| 258 | image_dir.append(dir) | ||
| 259 | |||
| 260 | with open(dir.joinpath('prompt.txt'), 'w') as f: | ||
| 261 | f.write(prompt) | ||
| 242 | else: | 262 | else: |
| 243 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") | 263 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") |
| 244 | output_dir.mkdir(parents=True, exist_ok=True) | 264 | output_dir.mkdir(parents=True, exist_ok=True) |
| 265 | image_dir.append(output_dir) | ||
| 245 | 266 | ||
| 246 | args.seed = args.seed or torch.random.seed() | 267 | args.seed = args.seed or torch.random.seed() |
| 247 | 268 | ||
| @@ -293,12 +314,9 @@ def generate(output_dir, pipeline, args): | |||
| 293 | ).images | 314 | ).images |
| 294 | 315 | ||
| 295 | for j, image in enumerate(images): | 316 | for j, image in enumerate(images): |
| 296 | image_dir = output_dir | 317 | dir = image_dir[j % len(args.prompt)] |
| 297 | if use_subdirs: | 318 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) |
| 298 | image_dir = image_dir.joinpath(slugify(args.prompt[j % len(args.prompt)])[:100]) | 319 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) |
| 299 | image_dir.mkdir(parents=True, exist_ok=True) | ||
| 300 | image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) | ||
| 301 | image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) | ||
| 302 | 320 | ||
| 303 | if torch.cuda.is_available(): | 321 | if torch.cuda.is_available(): |
| 304 | torch.cuda.empty_cache() | 322 | torch.cuda.empty_cache() |
diff --git a/models/clip/prompt.py b/models/clip/prompt.py index 9da3955..a7380be 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py | |||
| @@ -1,4 +1,4 @@ | |||
| 1 | from typing import Union | 1 | from typing import Union, Optional |
| 2 | 2 | ||
| 3 | import torch | 3 | import torch |
| 4 | 4 | ||
| @@ -16,7 +16,7 @@ class PromptProcessor(): | |||
| 16 | padding="do_not_pad", | 16 | padding="do_not_pad", |
| 17 | ).input_ids | 17 | ).input_ids |
| 18 | 18 | ||
| 19 | def unify_input_ids(self, input_ids: list[int]): | 19 | def unify_input_ids(self, input_ids: list[list[int]]): |
| 20 | return self.tokenizer.pad( | 20 | return self.tokenizer.pad( |
| 21 | {"input_ids": input_ids}, | 21 | {"input_ids": input_ids}, |
| 22 | padding=True, | 22 | padding=True, |
| @@ -24,13 +24,15 @@ class PromptProcessor(): | |||
| 24 | return_tensors="pt" | 24 | return_tensors="pt" |
| 25 | ) | 25 | ) |
| 26 | 26 | ||
| 27 | def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None): | 27 | def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None): |
| 28 | prompts = input_ids.shape[0] | 28 | prompts = input_ids.shape[0] |
| 29 | 29 | ||
| 30 | input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | 30 | input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) |
| 31 | if position_ids is not None: | ||
| 32 | position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | ||
| 31 | if attention_mask is not None: | 33 | if attention_mask is not None: |
| 32 | attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) | 34 | attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) |
| 33 | 35 | ||
| 34 | text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0] | 36 | text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] |
| 35 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | 37 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) |
| 36 | return text_embeddings | 38 | return text_embeddings |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1fd86b1..4d1e0a3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -861,13 +861,13 @@ def main(): | |||
| 861 | 861 | ||
| 862 | loop = partial( | 862 | loop = partial( |
| 863 | run_model, | 863 | run_model, |
| 864 | vae=vae, | 864 | vae, |
| 865 | noise_scheduler=noise_scheduler, | 865 | noise_scheduler, |
| 866 | unet=unet, | 866 | unet, |
| 867 | prompt_processor=prompt_processor, | 867 | prompt_processor, |
| 868 | num_class_images=args.num_class_images, | 868 | args.num_class_images, |
| 869 | prior_loss_weight=args.prior_loss_weight, | 869 | args.prior_loss_weight, |
| 870 | seed=args.seed, | 870 | args.seed, |
| 871 | ) | 871 | ) |
| 872 | 872 | ||
| 873 | # We need to initialize the trackers we use, and also store our configuration. | 873 | # We need to initialize the trackers we use, and also store our configuration. |
diff --git a/train_ti.py b/train_ti.py index 164cf67..98385dd 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -814,13 +814,13 @@ def main(): | |||
| 814 | 814 | ||
| 815 | loop = partial( | 815 | loop = partial( |
| 816 | run_model, | 816 | run_model, |
| 817 | vae=vae, | 817 | vae, |
| 818 | noise_scheduler=noise_scheduler, | 818 | noise_scheduler, |
| 819 | unet=unet, | 819 | unet, |
| 820 | prompt_processor=prompt_processor, | 820 | prompt_processor, |
| 821 | num_class_images=args.num_class_images, | 821 | args.num_class_images, |
| 822 | prior_loss_weight=args.prior_loss_weight, | 822 | args.prior_loss_weight, |
| 823 | seed=args.seed, | 823 | args.seed, |
| 824 | ) | 824 | ) |
| 825 | 825 | ||
| 826 | # We need to initialize the trackers we use, and also store our configuration. | 826 | # We need to initialize the trackers we use, and also store our configuration. |
diff --git a/training/common.py b/training/common.py index 99a6e67..ab2741a 100644 --- a/training/common.py +++ b/training/common.py | |||
| @@ -40,7 +40,10 @@ def run_model( | |||
| 40 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 40 | noisy_latents = noisy_latents.to(dtype=unet.dtype) |
| 41 | 41 | ||
| 42 | # Get the text embedding for conditioning | 42 | # Get the text embedding for conditioning |
| 43 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | 43 | encoder_hidden_states = prompt_processor.get_embeddings( |
| 44 | batch["input_ids"], | ||
| 45 | batch["attention_mask"] | ||
| 46 | ) | ||
| 44 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | 47 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) |
| 45 | 48 | ||
| 46 | # Predict the noise residual | 49 | # Predict the noise residual |
