 data/csv.py                                          | 14
 dreambooth_plus.py                                   | 44
 infer.py                                             | 12
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 96
 4 files changed, 108 insertions(+), 58 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index aad970c..316c099 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -72,8 +72,8 @@ class CSVDataModule(pl.LightningDataModule):
         ]

     def prepare_data(self):
-        metadata = pd.read_csv(self.data_file)
-        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != "x"]
+        metadata = pd.read_json(self.data_file)
+        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True]
         num_images = len(metadata)

         valid_set_size = int(num_images * 0.2)
@@ -163,6 +163,12 @@ class CSVDataset(Dataset):

         example = {}

+        if isinstance(item.prompt, str):
+            item.prompt = [item.prompt]
+
+        if isinstance(item.nprompt, str):
+            item.nprompt = [item.nprompt]
+
         example["prompts"] = item.prompt
         example["nprompts"] = item.nprompt

@@ -177,7 +183,7 @@ class CSVDataset(Dataset):
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
             item.prompt.format(self.instance_identifier),
-            padding="do_not_pad",
+            padding="max_length",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
         ).input_ids
@@ -190,7 +196,7 @@ class CSVDataset(Dataset):
         example["class_images"] = class_image
         example["class_prompt_ids"] = self.tokenizer(
             item.prompt.format(self.class_identifier),
-            padding="do_not_pad",
+            padding="max_length",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
         ).input_ids
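
Note on the padding change above: switching from padding="do_not_pad" to padding="max_length" matters for the multi-sub-prompt handling this commit introduces, because sub-prompt embeddings are later concatenated along the sequence axis and only line up if every sub-prompt has the same token length. A minimal sketch of the effect, assuming the standard CLIP tokenizer used with Stable Diffusion (the example strings are made up):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# With padding="max_length", every sub-prompt comes back as exactly
# tokenizer.model_max_length (77) token ids, so per-example tensors stack.
ids = tokenizer(
    ["a photo of a {} dog", "highly detailed"],
    padding="max_length",
    truncation=True,
    max_length=tokenizer.model_max_length,
).input_ids
assert all(len(i) == tokenizer.model_max_length for i in ids)
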
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index a98417f..ae31377 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -124,7 +124,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1500,
+        default=1400,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -147,7 +147,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -469,9 +469,16 @@ class Checkpointer:

            for i in range(self.sample_batches):
                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                prompt = [
+                    [p.format(self.instance_identifier) for p in prompt]
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ][:self.sample_batch_size]
+                nprompt = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ][:self.sample_batch_size]

                samples = pipeline(
                    prompt=prompt,
@@ -666,6 +673,17 @@ def main():
        }
        return batch

+    def encode_input_ids(input_ids):
+        text_embeddings = []
+
+        for ids in input_ids:
+            embeddings = text_encoder(ids)[0]
+            embeddings = embeddings.reshape((1, -1, 768))
+            text_embeddings.append(embeddings)
+
+        text_embeddings = torch.cat(text_embeddings)
+        return text_embeddings
+
    datamodule = CSVDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
@@ -688,8 +706,10 @@ def main():
    missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]

    if len(missing_data) != 0:
-        batched_data = [missing_data[i:i+args.sample_batch_size]
-                        for i in range(0, len(missing_data), args.sample_batch_size)]
+        batched_data = [
+            missing_data[i:i+args.sample_batch_size]
+            for i in range(0, len(missing_data), args.sample_batch_size)
+        ]

        scheduler = EulerAScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
@@ -706,9 +726,9 @@ def main():

        with torch.inference_mode():
            for batch in batched_data:
-                image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
-                nprompt = [p.nprompt for p in batch]
+                image_name = [item.class_image_path for item in batch]
+                prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch]
+                nprompt = [item.nprompt for item in batch]

                images = pipeline(
                    prompt=prompt,
@@ -855,7 +875,7 @@ def main():
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -954,7 +974,7 @@ def main():

                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])

                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

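
For context, the encode_input_ids() helper added above expects batch["input_ids"] to hold, per training example, the padded token ids of one or more sub-prompts; each entry is run through the text encoder and reshaped so the sub-prompt embeddings sit end to end along the sequence axis. A hedged sketch of the resulting shapes, assuming the usual CLIP text encoder (model names and the collate step are assumptions, not shown in this diff):

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Two sub-prompts belonging to a single training example, each padded to 77 tokens.
ids = tokenizer(
    ["a photo of sks dog", "sitting on a beach"],
    padding="max_length",
    truncation=True,
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
).input_ids                                     # shape: (2, 77)

with torch.no_grad():
    embeddings = text_encoder(ids)[0]           # shape: (2, 77, 768)
embeddings = embeddings.reshape((1, -1, 768))   # shape: (1, 154, 768)
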
diff --git a/infer.py b/infer.py
index 1a0baf5..d744768 100644
--- a/infer.py
+++ b/infer.py
@@ -19,6 +19,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 torch.backends.cuda.matmul.allow_tf32 = True


+line_sep = " <OR> "
+
+
 default_args = {
     "model": None,
     "scheduler": "euler_a",
@@ -95,10 +98,12 @@ def create_cmd_parser():
     parser.add_argument(
         "--prompt",
         type=str,
+        nargs="+",
     )
     parser.add_argument(
         "--negative_prompt",
         type=str,
+        nargs="*",
     )
     parser.add_argument(
         "--image",
@@ -271,9 +276,14 @@ def generate(output_dir, pipeline, args):
        dynamic_ncols=True
    )

+    if isinstance(args.prompt, str):
+        args.prompt = [args.prompt]
+
+    prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size
+
    generator = torch.Generator(device="cuda").manual_seed(seed + i)
    images = pipeline(
-        prompt=[args.prompt] * args.batch_size,
+        prompt=prompt,
        height=args.height,
        width=args.width,
        negative_prompt=args.negative_prompt,
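
Usage note: with nargs="+" and the " <OR> " separator, a single CLI prompt can now carry several sub-prompts that the pipeline encodes independently. A minimal sketch of the parsing, mirroring the hunk above (the example prompt text is made up):

line_sep = " <OR> "

args_prompt = ["a medieval castle <OR> golden hour, volumetric light"]
batch_size = 2

prompt = [p.split(line_sep) for p in args_prompt] * batch_size
# [['a medieval castle', 'golden hour, volumetric light'],
#  ['a medieval castle', 'golden hour, volumetric light']]
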
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 8b08a6f..b68b028 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -24,6 +24,22 @@ def preprocess(image, w, h):
     return 2.0 * image - 1.0


+def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None):
+    if isinstance(prompt, str):
+        prompt = [prompt] * batch_size
+
+    if isinstance(prompt, list) and isinstance(prompt[0], str):
+        prompt = [[p] for p in prompt]
+
+    if isinstance(prompt, list) and isinstance(prompt[0], list):
+        prompt_size = prompt_size or max([len(p) for p in prompt])
+        prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt]
+    else:
+        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+    return prompt_size, prompt
+
+
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
@@ -85,11 +101,39 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)

+    def embeddings_for_prompt(self, prompt: List[List[str]]):
+        text_embeddings = []
+
+        for p in prompt:
+            inputs = self.tokenizer(
+                p,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
+            )
+            input_ids = inputs.input_ids
+
+            if input_ids.shape[-1] > self.tokenizer.model_max_length:
+                removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                print(f"Too many tokens: {removed_text}")
+                input_ids = input_ids[:, : self.tokenizer.model_max_length]
+
+            embeddings = self.text_encoder(input_ids.to(self.device))[0]
+            embeddings = embeddings.reshape((1, -1, 768))
+            text_embeddings.append(embeddings)
+
+        text_embeddings = torch.cat(text_embeddings)
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
-        negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt: Union[str, List[str], List[List[str]]],
+        negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
         strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
@@ -151,23 +195,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
                (nsfw) content, according to the `safety_checker`.
        """

-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if negative_prompt is None:
-            negative_prompt = [""] * batch_size
-        elif isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-        elif isinstance(negative_prompt, list):
-            if len(negative_prompt) != batch_size:
-                raise ValueError(
-                    f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
-        else:
-            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+        prompt_size, prompt = normalize_prompt(prompt)
+        batch_size = len(prompt)
+        _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size)
+
+        if len(negative_prompt) != batch_size:
+            raise ValueError(
+                f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -179,23 +213,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
        self.scheduler.set_timesteps(num_inference_steps)

        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            print(f"Too many tokens: {removed_text}")
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.embeddings_for_prompt(prompt)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -203,11 +221,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
-            max_length = text_input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
-            )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.embeddings_for_prompt(negative_prompt)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
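
Putting the pipeline changes together: normalize_prompt() accepts a plain string, a list of strings, or a list of sub-prompt lists and always returns equally sized sub-prompt lists (padding short ones with ""), and it forwards prompt_size so the negative prompt is padded to the same number of sub-prompts; embeddings_for_prompt() then turns each sub-prompt list into a (1, n_subprompts * 77, 768) tensor, keeping conditional and unconditional embeddings the same sequence length for classifier-free guidance. The following standalone sketch re-states the normalization logic for illustration only; it mirrors the function added above and is not imported from the repo:

from typing import List, Optional, Union

def normalize_prompt(prompt: Union[str, List[str], List[List[str]]],
                     batch_size: int = 1, prompt_size: Optional[int] = None):
    # str -> one sub-prompt, repeated for the whole batch
    if isinstance(prompt, str):
        prompt = [prompt] * batch_size
    # list of strings -> one single-element sub-prompt list per batch entry
    if isinstance(prompt, list) and isinstance(prompt[0], str):
        prompt = [[p] for p in prompt]
    # pad every sub-prompt list to the same length with empty strings
    if isinstance(prompt, list) and isinstance(prompt[0], list):
        prompt_size = prompt_size or max(len(p) for p in prompt)
        prompt = [p + [""] * (prompt_size - len(p)) for p in prompt]
    else:
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    return prompt_size, prompt

print(normalize_prompt("a cat"))                          # (1, [['a cat']])
print(normalize_prompt(["a cat", "a dog"]))               # (1, [['a cat'], ['a dog']])
print(normalize_prompt([["a cat", "night"], ["a dog"]]))  # (2, [['a cat', 'night'], ['a dog', '']])
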