-rw-r--r--   data/csv.py     | 29
-rw-r--r--   dreambooth.py   | 10
-rw-r--r--   infer.py        | 27

3 files changed, 48 insertions(+), 18 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 67ac43b..23b5299 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,6 +1,7 @@
 import math
 import torch
 import json
+import numpy as np
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
@@ -15,6 +16,19 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
+def shuffle_prompt(prompt: str):
+    def handle_block(block: str):
+        words = block.split(", ")
+        np.random.shuffle(words)
+        return ", ".join(words)
+
+    prompt = prompt.split(". ")
+    prompt = [handle_block(b) for b in prompt]
+    np.random.shuffle(prompt)
+    prompt = ". ".join(prompt)
+    return prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
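Note: the new shuffle_prompt helper splits a prompt into sentence blocks on ". ", shuffles the comma-separated terms inside each block, then shuffles the block order, so each access sees a differently ordered caption. A standalone sketch of the same logic (the fixed seed is for illustration only and is not part of the patch):

    import numpy as np

    def shuffle_prompt(prompt: str) -> str:
        def handle_block(block: str) -> str:
            words = block.split(", ")
            np.random.shuffle(words)   # shuffle terms within a block, in place
            return ", ".join(words)

        blocks = [handle_block(b) for b in prompt.split(". ")]
        np.random.shuffle(blocks)      # shuffle the block order
        return ". ".join(blocks)

    np.random.seed(0)  # illustration only, not in the patch
    print(shuffle_prompt("a dog, a park, sunny. best quality, 4k"))
    # one possible output: "best quality, 4k. sunny, a dog, a park"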
@@ -190,30 +204,27 @@ class CSVDataset(Dataset):
         item = self.data[i % self.num_instance_images]
 
         example = {}
-
         example["prompts"] = item.prompt
         example["nprompts"] = item.nprompt
-
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
-
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
 
         return example
 
     def __getitem__(self, i):
-        example = {}
         unprocessed_example = self.get_example(i)
 
-        example["prompts"] = unprocessed_example["prompts"]
+        example = {}
+
+        example["prompts"] = shuffle_prompt(unprocessed_example["prompts"])
         example["nprompts"] = unprocessed_example["nprompts"]
+
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
-        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
+        example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
-            example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier)
+            example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier)
 
         return example
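Note: moving get_input_ids out of get_example matters because the prompt is now shuffled on every __getitem__ call, so the token IDs must be derived from the shuffled text rather than cached alongside the raw prompt. A toy sketch of the ordering that matters (fake_tokenize is a hypothetical stand-in for get_input_ids):

    def fake_tokenize(text: str):
        # hypothetical stand-in for self.get_input_ids(prompt, identifier)
        return text.split()

    raw = "red, blue. a cat"
    shuffled = shuffle_prompt(raw)  # fresh shuffle on every access
    ids = fake_tokenize(shuffled)   # IDs computed after shuffling, so they
                                    # match the text the model actually sees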
diff --git a/dreambooth.py b/dreambooth.py
index ec9531e..0044c1e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,13 +1,11 @@
 import argparse
 import itertools
 import math
-import os
 import datetime
 import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -299,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -613,7 +611,7 @@ def main():
     )
 
     # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    vae.requires_grad_(False)
 
     if len(args.placeholder_token) != 0:
         print(f"Adding text embeddings: {args.placeholder_token}")
@@ -629,6 +627,10 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    print(f"Token ID mappings:")
+    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+        print(f"- {token_id} {token}")
+
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
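Note: adding placeholder tokens grows the tokenizer vocabulary, so the text encoder's embedding matrix has to be resized to match before the new IDs are usable. A minimal sketch with a hypothetical placeholder token (model id illustrative):

    from transformers import CLIPTextModel, CLIPTokenizer

    name = "openai/clip-vit-large-patch14"  # illustrative model id
    tokenizer = CLIPTokenizer.from_pretrained(name)
    text_encoder = CLIPTextModel.from_pretrained(name)

    placeholder_tokens = ["<my-style>"]     # hypothetical placeholder token
    tokenizer.add_tokens(placeholder_tokens)
    ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
    for token_id, token in zip(ids, placeholder_tokens):
        print(f"- {token_id} {token}")
    text_encoder.resize_token_embeddings(len(tokenizer))  # match grown vocab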
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -8,7 +8,18 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    PNDMScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    DDIMScheduler,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler
+)
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
@@ -29,7 +40,7 @@ default_args = {
 
 
 default_cmds = {
-    "scheduler": "dpmpp",
+    "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -38,7 +49,7 @@ default_cmds = {
     "height": 512,
     "batch_size": 1,
     "batch_num": 1,
-    "steps": 50,
+    "steps": 30,
     "guidance_scale": 7.0,
     "seed": None,
     "config": None,
@@ -90,7 +101,7 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
     )
     parser.add_argument(
         "--prompt",
@@ -252,10 +263,16 @@ def generate(output_dir, pipeline, args):
         pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "ddim":
         pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
-    elif args.scheduler == "dpmpp":
+    elif args.scheduler == "dpmsm":
         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmss":
+        pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "euler_a":
         pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2":
+        pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2_a":
+        pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
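Note: every branch here uses the same pattern: Scheduler.from_config(pipeline.scheduler.config) builds a scheduler of a different class from the pipeline's existing scheduler configuration, so the noise-schedule parameters carry over. A minimal sketch of the new dpmss path (model id illustrative):

    from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler

    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5"  # illustrative model id
    )
    # Rebuild the scheduler as a singlestep DPM-Solver from the same config.
    pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)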
