author     Volpeon <git@volpeon.ink>   2022-12-10 08:43:34 +0100
committer  Volpeon <git@volpeon.ink>   2022-12-10 08:43:34 +0100
commit     64c79cc3e7fad49131f90fbb0648b6d5587563e5 (patch)
tree       372bb09a8c952bd28a8da069659da26ce2c99894
parent     Fix sample steps (diff)
download   textual-inversion-diff-64c79cc3e7fad49131f90fbb0648b6d5587563e5.tar.gz
           textual-inversion-diff-64c79cc3e7fad49131f90fbb0648b6d5587563e5.tar.bz2
           textual-inversion-diff-64c79cc3e7fad49131f90fbb0648b6d5587563e5.zip
Various updates; shuffle prompt content during training
-rw-r--r--  data/csv.py     | 29
-rw-r--r--  dreambooth.py   | 10
-rw-r--r--  infer.py        | 27
3 files changed, 48 insertions(+), 18 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 67ac43b..23b5299 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,6 +1,7 @@
 import math
 import torch
 import json
+import numpy as np
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
@@ -15,6 +16,19 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
+def shuffle_prompt(prompt: str):
+    def handle_block(block: str):
+        words = block.split(", ")
+        np.random.shuffle(words)
+        return ", ".join(words)
+
+    prompt = prompt.split(". ")
+    prompt = [handle_block(b) for b in prompt]
+    np.random.shuffle(prompt)
+    prompt = ". ".join(prompt)
+    return prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
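Note: the new shuffle_prompt helper treats the prompt as sentence-separated blocks of comma-separated tags, shuffles the tags inside each block, then shuffles the block order. A minimal illustration; the sample prompt and seed below are made up for demonstration only:

import numpy as np

np.random.seed(0)  # hypothetical seed, only to make the illustration repeatable
prompt = "a photo of a dog, sitting, park. high quality, sharp focus"
print(shuffle_prompt(prompt))
# one possible output (tags stay within their original sentence block):
# "sharp focus, high quality. sitting, a photo of a dog, park"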
@@ -190,30 +204,27 @@ class CSVDataset(Dataset):
         item = self.data[i % self.num_instance_images]
 
         example = {}
-
         example["prompts"] = item.prompt
         example["nprompts"] = item.nprompt
-
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
-
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
 
         return example
 
     def __getitem__(self, i):
-        example = {}
         unprocessed_example = self.get_example(i)
 
-        example["prompts"] = unprocessed_example["prompts"]
+        example = {}
+
+        example["prompts"] = shuffle_prompt(unprocessed_example["prompts"])
         example["nprompts"] = unprocessed_example["nprompts"]
+
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
-        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
+        example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
-            example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier)
 
         return example
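Note: tokenization now happens in __getitem__ instead of get_example, so the prompt is shuffled and re-tokenized on every access and each epoch sees a different ordering of the same content. A rough sketch of the effect; the CSVDataset constructor arguments shown are hypothetical placeholders, not the real signature:

# hypothetical construction; the real CSVDataset takes more/other arguments
dataset = CSVDataset(data_file="instances.json", tokenizer=tokenizer, instance_identifier="<token>")

a = dataset[0]["prompts"]
b = dataset[0]["prompts"]
# a and b usually differ in tag/sentence order, and instance_prompt_ids /
# class_prompt_ids are recomputed from the shuffled text on each fetch.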
diff --git a/dreambooth.py b/dreambooth.py
index ec9531e..0044c1e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,13 +1,11 @@
 import argparse
 import itertools
 import math
-import os
 import datetime
 import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -299,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -613,7 +611,7 @@ def main():
     )
 
     # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    vae.requires_grad_(False)
 
     if len(args.placeholder_token) != 0:
         print(f"Adding text embeddings: {args.placeholder_token}")
@@ -629,6 +627,10 @@
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    print(f"Token ID mappings:")
+    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+        print(f"- {token_id} {token}")
+
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
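Note: freezing the VAE now uses torch's built-in Module.requires_grad_ instead of the removed freeze_params helper. Assuming freeze_params was the usual loop over parameters (its body is not shown in this diff), the two are equivalent:

import torch

def freeze_params(params):
    # presumed implementation of the removed helper: disable gradients one tensor at a time
    for p in params:
        p.requires_grad = False

vae = torch.nn.Linear(4, 4)          # stand-in module; the real code freezes an AutoencoderKL
freeze_params(vae.parameters())      # old style
vae.requires_grad_(False)            # new style, same effect in a single call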
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -8,7 +8,18 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    PNDMScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    DDIMScheduler,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler
+)
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
@@ -29,7 +40,7 @@ default_args = {
 
 
 default_cmds = {
-    "scheduler": "dpmpp",
+    "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -38,7 +49,7 @@ default_cmds = {
     "height": 512,
     "batch_size": 1,
     "batch_num": 1,
-    "steps": 50,
+    "steps": 30,
     "guidance_scale": 7.0,
     "seed": None,
     "config": None,
@@ -90,7 +101,7 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
     )
     parser.add_argument(
         "--prompt",
@@ -252,10 +263,16 @@ def generate(output_dir, pipeline, args):
         pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "ddim":
         pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
-    elif args.scheduler == "dpmpp":
+    elif args.scheduler == "dpmsm":
         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmss":
+        pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "euler_a":
         pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2":
+        pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2_a":
+        pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
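Note: the renamed and newly added scheduler choices all follow the same pattern of rebuilding the scheduler from the pipeline's existing config. A minimal sketch of doing the same thing outside infer.py with a stock diffusers pipeline; the model path is a placeholder and infer.py itself uses its own pipeline class:

import torch
from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler

# "path/to/model" is a placeholder; point it at the trained model directory
pipeline = StableDiffusionPipeline.from_pretrained("path/to/model", torch_dtype=torch.float16)
# swap in the singlestep DPM-Solver, mirroring --scheduler dpmss above
pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
pipeline.to("cuda")

image = pipeline("a photo of a corgi", num_inference_steps=30, guidance_scale=7.0).images[0]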