summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py29
-rw-r--r--dreambooth.py10
-rw-r--r--infer.py27
3 files changed, 48 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py
index 67ac43b..23b5299 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,6 +1,7 @@
1import math 1import math
2import torch 2import torch
3import json 3import json
4import numpy as np
4from pathlib import Path 5from pathlib import Path
5import pytorch_lightning as pl 6import pytorch_lightning as pl
6from PIL import Image 7from PIL import Image
@@ -15,6 +16,19 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
15 return {"content": prompt} if isinstance(prompt, str) else prompt 16 return {"content": prompt} if isinstance(prompt, str) else prompt
16 17
17 18
19def shuffle_prompt(prompt: str):
20 def handle_block(block: str):
21 words = block.split(", ")
22 np.random.shuffle(words)
23 return ", ".join(words)
24
25 prompt = prompt.split(". ")
26 prompt = [handle_block(b) for b in prompt]
27 np.random.shuffle(prompt)
28 prompt = ". ".join(prompt)
29 return prompt
30
31
18class CSVDataItem(NamedTuple): 32class CSVDataItem(NamedTuple):
19 instance_image_path: Path 33 instance_image_path: Path
20 class_image_path: Path 34 class_image_path: Path
@@ -190,30 +204,27 @@ class CSVDataset(Dataset):
190 item = self.data[i % self.num_instance_images] 204 item = self.data[i % self.num_instance_images]
191 205
192 example = {} 206 example = {}
193
194 example["prompts"] = item.prompt 207 example["prompts"] = item.prompt
195 example["nprompts"] = item.nprompt 208 example["nprompts"] = item.nprompt
196
197 example["instance_images"] = self.get_image(item.instance_image_path) 209 example["instance_images"] = self.get_image(item.instance_image_path)
198 example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
199
200 if self.num_class_images != 0: 210 if self.num_class_images != 0:
201 example["class_images"] = self.get_image(item.class_image_path) 211 example["class_images"] = self.get_image(item.class_image_path)
202 example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
203 212
204 return example 213 return example
205 214
206 def __getitem__(self, i): 215 def __getitem__(self, i):
207 example = {}
208 unprocessed_example = self.get_example(i) 216 unprocessed_example = self.get_example(i)
209 217
210 example["prompts"] = unprocessed_example["prompts"] 218 example = {}
219
220 example["prompts"] = shuffle_prompt(unprocessed_example["prompts"])
211 example["nprompts"] = unprocessed_example["nprompts"] 221 example["nprompts"] = unprocessed_example["nprompts"]
222
212 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 223 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
213 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] 224 example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier)
214 225
215 if self.num_class_images != 0: 226 if self.num_class_images != 0:
216 example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) 227 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
217 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] 228 example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier)
218 229
219 return example 230 return example
diff --git a/dreambooth.py b/dreambooth.py
index ec9531e..0044c1e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,13 +1,11 @@
1import argparse 1import argparse
2import itertools 2import itertools
3import math 3import math
4import os
5import datetime 4import datetime
6import logging 5import logging
7import json 6import json
8from pathlib import Path 7from pathlib import Path
9 8
10import numpy as np
11import torch 9import torch
12import torch.nn.functional as F 10import torch.nn.functional as F
13import torch.utils.checkpoint 11import torch.utils.checkpoint
@@ -299,7 +297,7 @@ def parse_args():
299 parser.add_argument( 297 parser.add_argument(
300 "--sample_steps", 298 "--sample_steps",
301 type=int, 299 type=int,
302 default=20, 300 default=15,
303 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 301 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
304 ) 302 )
305 parser.add_argument( 303 parser.add_argument(
@@ -613,7 +611,7 @@ def main():
613 ) 611 )
614 612
615 # Freeze text_encoder and vae 613 # Freeze text_encoder and vae
616 freeze_params(vae.parameters()) 614 vae.requires_grad_(False)
617 615
618 if len(args.placeholder_token) != 0: 616 if len(args.placeholder_token) != 0:
619 print(f"Adding text embeddings: {args.placeholder_token}") 617 print(f"Adding text embeddings: {args.placeholder_token}")
@@ -629,6 +627,10 @@ def main():
629 627
630 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) 628 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
631 629
630 print(f"Token ID mappings:")
631 for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
632 print(f"- {token_id} {token}")
633
632 # Resize the token embeddings as we are adding new special tokens to the tokenizer 634 # Resize the token embeddings as we are adding new special tokens to the tokenizer
633 text_encoder.resize_token_embeddings(len(tokenizer)) 635 text_encoder.resize_token_embeddings(len(tokenizer))
634 636
diff --git a/infer.py b/infer.py
index 30e11cf..e3fa9e5 100644
--- a/infer.py
+++ b/infer.py
@@ -8,7 +8,18 @@ from pathlib import Path
8import torch 8import torch
9import json 9import json
10from PIL import Image 10from PIL import Image
11from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler 11from diffusers import (
12 AutoencoderKL,
13 UNet2DConditionModel,
14 PNDMScheduler,
15 DPMSolverMultistepScheduler,
16 DPMSolverSinglestepScheduler,
17 DDIMScheduler,
18 LMSDiscreteScheduler,
19 EulerAncestralDiscreteScheduler,
20 KDPM2DiscreteScheduler,
21 KDPM2AncestralDiscreteScheduler
22)
12from transformers import CLIPTextModel, CLIPTokenizer 23from transformers import CLIPTextModel, CLIPTokenizer
13from slugify import slugify 24from slugify import slugify
14 25
@@ -29,7 +40,7 @@ default_args = {
29 40
30 41
31default_cmds = { 42default_cmds = {
32 "scheduler": "dpmpp", 43 "scheduler": "dpmsm",
33 "prompt": None, 44 "prompt": None,
34 "negative_prompt": None, 45 "negative_prompt": None,
35 "image": None, 46 "image": None,
@@ -38,7 +49,7 @@ default_cmds = {
38 "height": 512, 49 "height": 512,
39 "batch_size": 1, 50 "batch_size": 1,
40 "batch_num": 1, 51 "batch_num": 1,
41 "steps": 50, 52 "steps": 30,
42 "guidance_scale": 7.0, 53 "guidance_scale": 7.0,
43 "seed": None, 54 "seed": None,
44 "config": None, 55 "config": None,
@@ -90,7 +101,7 @@ def create_cmd_parser():
90 parser.add_argument( 101 parser.add_argument(
91 "--scheduler", 102 "--scheduler",
92 type=str, 103 type=str,
93 choices=["plms", "ddim", "klms", "dpmpp", "euler_a"], 104 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
94 ) 105 )
95 parser.add_argument( 106 parser.add_argument(
96 "--prompt", 107 "--prompt",
@@ -252,10 +263,16 @@ def generate(output_dir, pipeline, args):
252 pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) 263 pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
253 elif args.scheduler == "ddim": 264 elif args.scheduler == "ddim":
254 pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) 265 pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
255 elif args.scheduler == "dpmpp": 266 elif args.scheduler == "dpmsm":
256 pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 267 pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
268 elif args.scheduler == "dpmss":
269 pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
257 elif args.scheduler == "euler_a": 270 elif args.scheduler == "euler_a":
258 pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) 271 pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
272 elif args.scheduler == "kdpm2":
273 pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
274 elif args.scheduler == "kdpm2_a":
275 pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
259 276
260 with torch.autocast("cuda"), torch.inference_mode(): 277 with torch.autocast("cuda"), torch.inference_mode():
261 for i in range(args.batch_num): 278 for i in range(args.batch_num):