summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py21
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py61
-rw-r--r--scripts/convert_diffusers_to_original_stable_diffusion.py234
-rw-r--r--scripts/convert_original_stable_diffusion_to_diffusers.py690
-rw-r--r--train_dreambooth.py14
-rw-r--r--train_lora.py16
-rw-r--r--train_ti.py14
-rw-r--r--training/functional.py36
8 files changed, 99 insertions, 987 deletions
diff --git a/data/csv.py b/data/csv.py
index fba5d4b..a6cd065 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -99,14 +99,16 @@ def generate_buckets(
99 return buckets, bucket_items, bucket_assignments 99 return buckets, bucket_items, bucket_assignments
100 100
101 101
102def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples): 102def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
103 prompt_ids = [example["prompt_ids"] for example in examples] 103 prompt_ids = [example["prompt_ids"] for example in examples]
104 nprompt_ids = [example["nprompt_ids"] for example in examples] 104 nprompt_ids = [example["nprompt_ids"] for example in examples]
105 105
106 input_ids = [example["instance_prompt_ids"] for example in examples] 106 input_ids = [example["instance_prompt_ids"] for example in examples]
107 pixel_values = [example["instance_images"] for example in examples] 107 pixel_values = [example["instance_images"] for example in examples]
108 108
109 if with_prior_preservation: 109 if with_guidance:
110 input_ids += [example["negative_prompt_ids"] for example in examples]
111 elif with_prior_preservation:
110 input_ids += [example["class_prompt_ids"] for example in examples] 112 input_ids += [example["class_prompt_ids"] for example in examples]
111 pixel_values += [example["class_images"] for example in examples] 113 pixel_values += [example["class_images"] for example in examples]
112 114
@@ -133,7 +135,7 @@ class VlpnDataItem(NamedTuple):
133 class_image_path: Path 135 class_image_path: Path
134 prompt: list[str] 136 prompt: list[str]
135 cprompt: str 137 cprompt: str
136 nprompt: str 138 nprompt: list[str]
137 collection: list[str] 139 collection: list[str]
138 140
139 141
@@ -163,6 +165,7 @@ class VlpnDataModule():
163 data_file: str, 165 data_file: str,
164 tokenizer: CLIPTokenizer, 166 tokenizer: CLIPTokenizer,
165 class_subdir: str = "cls", 167 class_subdir: str = "cls",
168 with_guidance: bool = False,
166 num_class_images: int = 1, 169 num_class_images: int = 1,
167 size: int = 768, 170 size: int = 768,
168 num_buckets: int = 0, 171 num_buckets: int = 0,
@@ -191,6 +194,7 @@ class VlpnDataModule():
191 self.class_root = self.data_root / class_subdir 194 self.class_root = self.data_root / class_subdir
192 self.class_root.mkdir(parents=True, exist_ok=True) 195 self.class_root.mkdir(parents=True, exist_ok=True)
193 self.num_class_images = num_class_images 196 self.num_class_images = num_class_images
197 self.with_guidance = with_guidance
194 198
195 self.tokenizer = tokenizer 199 self.tokenizer = tokenizer
196 self.size = size 200 self.size = size
@@ -228,10 +232,10 @@ class VlpnDataModule():
228 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 232 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
229 expansions 233 expansions
230 )), 234 )),
231 keywords_to_prompt(prompt_to_keywords( 235 prompt_to_keywords(
232 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 236 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
233 expansions 237 expansions
234 )), 238 ),
235 item["collection"].split(", ") if "collection" in item else [] 239 item["collection"].split(", ") if "collection" in item else []
236 ) 240 )
237 for item in data 241 for item in data
@@ -279,7 +283,7 @@ class VlpnDataModule():
279 if self.seed is not None: 283 if self.seed is not None:
280 generator = generator.manual_seed(self.seed) 284 generator = generator.manual_seed(self.seed)
281 285
282 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) 286 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
283 287
284 if valid_set_size == 0: 288 if valid_set_size == 0:
285 data_train, data_val = items, items 289 data_train, data_val = items, items
@@ -443,11 +447,14 @@ class VlpnDataset(IterableDataset):
443 example = {} 447 example = {}
444 448
445 example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) 449 example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
446 example["nprompt_ids"] = self.get_input_ids(item.nprompt) 450 example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt))
447 451
448 example["instance_prompt_ids"] = self.get_input_ids( 452 example["instance_prompt_ids"] = self.get_input_ids(
449 keywords_to_prompt(item.prompt, self.dropout, True) 453 keywords_to_prompt(item.prompt, self.dropout, True)
450 ) 454 )
455 example["negative_prompt_ids"] = self.get_input_ids(
456 keywords_to_prompt(item.nprompt, self.dropout, True)
457 )
451 example["instance_images"] = image_transforms(get_image(item.instance_image_path)) 458 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
452 459
453 if self.num_class_images != 0: 460 if self.num_class_images != 0:
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index ea2a656..127ca50 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -307,39 +307,45 @@ class VlpnStableDiffusion(DiffusionPipeline):
307 307
308 return timesteps, num_inference_steps - t_start 308 return timesteps, num_inference_steps - t_start
309 309
310 def prepare_image(self, batch_size, width, height, dtype, device, generator=None): 310 def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None):
311 return (1.4 * perlin_noise( 311 offset_image = perlin_noise(
312 (batch_size, 1, width, height), 312 (batch_size, 1, width, height),
313 res=1, 313 res=1,
314 octaves=4,
315 generator=generator, 314 generator=generator,
316 dtype=dtype, 315 dtype=dtype,
317 device=device 316 device=device
318 )).clamp(-1, 1).expand(batch_size, 3, width, height) 317 )
318 offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator)
319 offset_latents = self.vae.config.scaling_factor * offset_latents
320 return offset_latents
319 321
320 def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): 322 def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None):
321 init_image = init_image.to(device=device, dtype=dtype) 323 init_image = init_image.to(device=device, dtype=dtype)
322 init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) 324 latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
323 init_latents = self.vae.config.scaling_factor * init_latents 325 latents = self.vae.config.scaling_factor * latents
324 326
325 if batch_size % init_latents.shape[0] != 0: 327 if batch_size % latents.shape[0] != 0:
326 raise ValueError( 328 raise ValueError(
327 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 329 f"Cannot duplicate `init_image` of batch size {latents.shape[0]} to {batch_size} text prompts."
328 ) 330 )
329 else: 331 else:
330 batch_multiplier = batch_size // init_latents.shape[0] 332 batch_multiplier = batch_size // latents.shape[0]
331 init_latents = torch.cat([init_latents] * batch_multiplier, dim=0) 333 latents = torch.cat([latents] * batch_multiplier, dim=0)
332 334
333 # add noise to latents using the timesteps 335 # add noise to latents using the timesteps
334 noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) 336 noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype)
337
338 if brightness_offset != 0:
339 noise += brightness_offset * self.prepare_brightness_offset(
340 batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator
341 )
335 342
336 # get latents 343 # get latents
337 init_latents = self.scheduler.add_noise(init_latents, noise, timestep) 344 latents = self.scheduler.add_noise(latents, noise, timestep)
338 latents = init_latents
339 345
340 return latents 346 return latents
341 347
342 def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 348 def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None):
343 shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) 349 shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
344 if isinstance(generator, list) and len(generator) != batch_size: 350 if isinstance(generator, list) and len(generator) != batch_size:
345 raise ValueError( 351 raise ValueError(
@@ -352,6 +358,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
352 else: 358 else:
353 latents = latents.to(device) 359 latents = latents.to(device)
354 360
361 if brightness_offset != 0:
362 latents += brightness_offset * self.prepare_brightness_offset(
363 batch_size, height, width, dtype, device, generator
364 )
365
355 # scale the initial noise by the standard deviation required by the scheduler 366 # scale the initial noise by the standard deviation required by the scheduler
356 latents = latents * self.scheduler.init_noise_sigma 367 latents = latents * self.scheduler.init_noise_sigma
357 return latents 368 return latents
@@ -395,7 +406,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
395 sag_scale: float = 0.75, 406 sag_scale: float = 0.75,
396 eta: float = 0.0, 407 eta: float = 0.0,
397 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 408 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
398 image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None, 409 image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
410 brightness_offset: Union[float, torch.FloatTensor] = 0,
399 output_type: str = "pil", 411 output_type: str = "pil",
400 return_dict: bool = True, 412 return_dict: bool = True,
401 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 413 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -468,7 +480,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
468 num_channels_latents = self.unet.in_channels 480 num_channels_latents = self.unet.in_channels
469 do_classifier_free_guidance = guidance_scale > 1.0 481 do_classifier_free_guidance = guidance_scale > 1.0
470 do_self_attention_guidance = sag_scale > 0.0 482 do_self_attention_guidance = sag_scale > 0.0
471 prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise" 483 prep_from_image = isinstance(image, PIL.Image.Image)
472 484
473 # 3. Encode input prompt 485 # 3. Encode input prompt
474 prompt_embeds = self.encode_prompt( 486 prompt_embeds = self.encode_prompt(
@@ -482,15 +494,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
482 # 4. Prepare latent variables 494 # 4. Prepare latent variables
483 if isinstance(image, PIL.Image.Image): 495 if isinstance(image, PIL.Image.Image):
484 image = preprocess(image) 496 image = preprocess(image)
485 elif image == "noise":
486 image = self.prepare_image(
487 batch_size * num_images_per_prompt,
488 width,
489 height,
490 prompt_embeds.dtype,
491 device,
492 generator
493 )
494 497
495 # 5. Prepare timesteps 498 # 5. Prepare timesteps
496 self.scheduler.set_timesteps(num_inference_steps, device=device) 499 self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -503,9 +506,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
503 image, 506 image,
504 latent_timestep, 507 latent_timestep,
505 batch_size * num_images_per_prompt, 508 batch_size * num_images_per_prompt,
509 brightness_offset,
506 prompt_embeds.dtype, 510 prompt_embeds.dtype,
507 device, 511 device,
508 generator 512 generator,
509 ) 513 )
510 else: 514 else:
511 latents = self.prepare_latents( 515 latents = self.prepare_latents(
@@ -513,10 +517,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
513 num_channels_latents, 517 num_channels_latents,
514 height, 518 height,
515 width, 519 width,
520 brightness_offset,
516 prompt_embeds.dtype, 521 prompt_embeds.dtype,
517 device, 522 device,
518 generator, 523 generator,
519 image 524 image,
520 ) 525 )
521 526
522 # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 527 # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py
deleted file mode 100644
index 9888f62..0000000
--- a/scripts/convert_diffusers_to_original_stable_diffusion.py
+++ /dev/null
@@ -1,234 +0,0 @@
1# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2# *Only* converts the UNet, VAE, and Text Encoder.
3# Does not convert optimizer state or any other thing.
4
5import argparse
6import os.path as osp
7
8import torch
9
10
11# =================#
12# UNet Conversion #
13# =================#
14
15unet_conversion_map = [
16 # (stable-diffusion, HF Diffusers)
17 ("time_embed.0.weight", "time_embedding.linear_1.weight"),
18 ("time_embed.0.bias", "time_embedding.linear_1.bias"),
19 ("time_embed.2.weight", "time_embedding.linear_2.weight"),
20 ("time_embed.2.bias", "time_embedding.linear_2.bias"),
21 ("input_blocks.0.0.weight", "conv_in.weight"),
22 ("input_blocks.0.0.bias", "conv_in.bias"),
23 ("out.0.weight", "conv_norm_out.weight"),
24 ("out.0.bias", "conv_norm_out.bias"),
25 ("out.2.weight", "conv_out.weight"),
26 ("out.2.bias", "conv_out.bias"),
27]
28
29unet_conversion_map_resnet = [
30 # (stable-diffusion, HF Diffusers)
31 ("in_layers.0", "norm1"),
32 ("in_layers.2", "conv1"),
33 ("out_layers.0", "norm2"),
34 ("out_layers.3", "conv2"),
35 ("emb_layers.1", "time_emb_proj"),
36 ("skip_connection", "conv_shortcut"),
37]
38
39unet_conversion_map_layer = []
40# hardcoded number of downblocks and resnets/attentions...
41# would need smarter logic for other networks.
42for i in range(4):
43 # loop over downblocks/upblocks
44
45 for j in range(2):
46 # loop over resnets/attentions for downblocks
47 hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
48 sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
49 unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
50
51 if i < 3:
52 # no attention layers in down_blocks.3
53 hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
54 sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
55 unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
56
57 for j in range(3):
58 # loop over resnets/attentions for upblocks
59 hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
60 sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
61 unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
62
63 if i > 0:
64 # no attention layers in up_blocks.0
65 hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
66 sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
67 unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
68
69 if i < 3:
70 # no downsample in down_blocks.3
71 hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
72 sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
73 unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
74
75 # no upsample in up_blocks.3
76 hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
77 sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
78 unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
79
80hf_mid_atn_prefix = "mid_block.attentions.0."
81sd_mid_atn_prefix = "middle_block.1."
82unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
83
84for j in range(2):
85 hf_mid_res_prefix = f"mid_block.resnets.{j}."
86 sd_mid_res_prefix = f"middle_block.{2*j}."
87 unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
88
89
90def convert_unet_state_dict(unet_state_dict):
91 # buyer beware: this is a *brittle* function,
92 # and correct output requires that all of these pieces interact in
93 # the exact order in which I have arranged them.
94 mapping = {k: k for k in unet_state_dict.keys()}
95 for sd_name, hf_name in unet_conversion_map:
96 mapping[hf_name] = sd_name
97 for k, v in mapping.items():
98 if "resnets" in k:
99 for sd_part, hf_part in unet_conversion_map_resnet:
100 v = v.replace(hf_part, sd_part)
101 mapping[k] = v
102 for k, v in mapping.items():
103 for sd_part, hf_part in unet_conversion_map_layer:
104 v = v.replace(hf_part, sd_part)
105 mapping[k] = v
106 new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
107 return new_state_dict
108
109
110# ================#
111# VAE Conversion #
112# ================#
113
114vae_conversion_map = [
115 # (stable-diffusion, HF Diffusers)
116 ("nin_shortcut", "conv_shortcut"),
117 ("norm_out", "conv_norm_out"),
118 ("mid.attn_1.", "mid_block.attentions.0."),
119]
120
121for i in range(4):
122 # down_blocks have two resnets
123 for j in range(2):
124 hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
125 sd_down_prefix = f"encoder.down.{i}.block.{j}."
126 vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
127
128 if i < 3:
129 hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
130 sd_downsample_prefix = f"down.{i}.downsample."
131 vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
132
133 hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
134 sd_upsample_prefix = f"up.{3-i}.upsample."
135 vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
136
137 # up_blocks have three resnets
138 # also, up blocks in hf are numbered in reverse from sd
139 for j in range(3):
140 hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
141 sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
142 vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
143
144# this part accounts for mid blocks in both the encoder and the decoder
145for i in range(2):
146 hf_mid_res_prefix = f"mid_block.resnets.{i}."
147 sd_mid_res_prefix = f"mid.block_{i+1}."
148 vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
149
150
151vae_conversion_map_attn = [
152 # (stable-diffusion, HF Diffusers)
153 ("norm.", "group_norm."),
154 ("q.", "query."),
155 ("k.", "key."),
156 ("v.", "value."),
157 ("proj_out.", "proj_attn."),
158]
159
160
161def reshape_weight_for_sd(w):
162 # convert HF linear weights to SD conv2d weights
163 return w.reshape(*w.shape, 1, 1)
164
165
166def convert_vae_state_dict(vae_state_dict):
167 mapping = {k: k for k in vae_state_dict.keys()}
168 for k, v in mapping.items():
169 for sd_part, hf_part in vae_conversion_map:
170 v = v.replace(hf_part, sd_part)
171 mapping[k] = v
172 for k, v in mapping.items():
173 if "attentions" in k:
174 for sd_part, hf_part in vae_conversion_map_attn:
175 v = v.replace(hf_part, sd_part)
176 mapping[k] = v
177 new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
178 weights_to_convert = ["q", "k", "v", "proj_out"]
179 for k, v in new_state_dict.items():
180 for weight_name in weights_to_convert:
181 if f"mid.attn_1.{weight_name}.weight" in k:
182 print(f"Reshaping {k} for SD format")
183 new_state_dict[k] = reshape_weight_for_sd(v)
184 return new_state_dict
185
186
187# =========================#
188# Text Encoder Conversion #
189# =========================#
190# pretty much a no-op
191
192
193def convert_text_enc_state_dict(text_enc_dict):
194 return text_enc_dict
195
196
197if __name__ == "__main__":
198 parser = argparse.ArgumentParser()
199
200 parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
201 parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
202 parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
203
204 args = parser.parse_args()
205
206 assert args.model_path is not None, "Must provide a model path!"
207
208 assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
209
210 unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
211 vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
212 text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
213
214 # Convert the UNet model
215 unet_state_dict = torch.load(unet_path, map_location="cpu")
216 unet_state_dict = convert_unet_state_dict(unet_state_dict)
217 unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
218
219 # Convert the VAE model
220 vae_state_dict = torch.load(vae_path, map_location="cpu")
221 vae_state_dict = convert_vae_state_dict(vae_state_dict)
222 vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
223
224 # Convert the text encoder model
225 text_enc_dict = torch.load(text_enc_path, map_location="cpu")
226 text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
227 text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
228
229 # Put together new checkpoint
230 state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
231 if args.half:
232 state_dict = {k: v.half() for k, v in state_dict.items()}
233 state_dict = {"state_dict": state_dict}
234 torch.save(state_dict, args.checkpoint_path)
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
deleted file mode 100644
index ee7fc33..0000000
--- a/scripts/convert_original_stable_diffusion_to_diffusers.py
+++ /dev/null
@@ -1,690 +0,0 @@
1# coding=utf-8
2# Copyright 2022 The HuggingFace Inc. team.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Conversion script for the LDM checkpoints. """
16
17import argparse
18import os
19
20import torch
21
22
23try:
24 from omegaconf import OmegaConf
25except ImportError:
26 raise ImportError(
27 "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
28 )
29
30from diffusers import (
31 AutoencoderKL,
32 DDIMScheduler,
33 LDMTextToImagePipeline,
34 LMSDiscreteScheduler,
35 PNDMScheduler,
36 StableDiffusionPipeline,
37 UNet2DConditionModel,
38)
39from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
40from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
41from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
42
43
44def shave_segments(path, n_shave_prefix_segments=1):
45 """
46 Removes segments. Positive values shave the first segments, negative shave the last segments.
47 """
48 if n_shave_prefix_segments >= 0:
49 return ".".join(path.split(".")[n_shave_prefix_segments:])
50 else:
51 return ".".join(path.split(".")[:n_shave_prefix_segments])
52
53
54def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
55 """
56 Updates paths inside resnets to the new naming scheme (local renaming)
57 """
58 mapping = []
59 for old_item in old_list:
60 new_item = old_item.replace("in_layers.0", "norm1")
61 new_item = new_item.replace("in_layers.2", "conv1")
62
63 new_item = new_item.replace("out_layers.0", "norm2")
64 new_item = new_item.replace("out_layers.3", "conv2")
65
66 new_item = new_item.replace("emb_layers.1", "time_emb_proj")
67 new_item = new_item.replace("skip_connection", "conv_shortcut")
68
69 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
70
71 mapping.append({"old": old_item, "new": new_item})
72
73 return mapping
74
75
76def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
77 """
78 Updates paths inside resnets to the new naming scheme (local renaming)
79 """
80 mapping = []
81 for old_item in old_list:
82 new_item = old_item
83
84 new_item = new_item.replace("nin_shortcut", "conv_shortcut")
85 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
86
87 mapping.append({"old": old_item, "new": new_item})
88
89 return mapping
90
91
92def renew_attention_paths(old_list, n_shave_prefix_segments=0):
93 """
94 Updates paths inside attentions to the new naming scheme (local renaming)
95 """
96 mapping = []
97 for old_item in old_list:
98 new_item = old_item
99
100 # new_item = new_item.replace('norm.weight', 'group_norm.weight')
101 # new_item = new_item.replace('norm.bias', 'group_norm.bias')
102
103 # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
104 # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
105
106 # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
107
108 mapping.append({"old": old_item, "new": new_item})
109
110 return mapping
111
112
113def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
114 """
115 Updates paths inside attentions to the new naming scheme (local renaming)
116 """
117 mapping = []
118 for old_item in old_list:
119 new_item = old_item
120
121 new_item = new_item.replace("norm.weight", "group_norm.weight")
122 new_item = new_item.replace("norm.bias", "group_norm.bias")
123
124 new_item = new_item.replace("q.weight", "query.weight")
125 new_item = new_item.replace("q.bias", "query.bias")
126
127 new_item = new_item.replace("k.weight", "key.weight")
128 new_item = new_item.replace("k.bias", "key.bias")
129
130 new_item = new_item.replace("v.weight", "value.weight")
131 new_item = new_item.replace("v.bias", "value.bias")
132
133 new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
134 new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
135
136 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
137
138 mapping.append({"old": old_item, "new": new_item})
139
140 return mapping
141
142
143def assign_to_checkpoint(
144 paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
145):
146 """
147 This does the final conversion step: take locally converted weights and apply a global renaming
148 to them. It splits attention layers, and takes into account additional replacements
149 that may arise.
150
151 Assigns the weights to the new checkpoint.
152 """
153 assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
154
155 # Splits the attention layers into three variables.
156 if attention_paths_to_split is not None:
157 for path, path_map in attention_paths_to_split.items():
158 old_tensor = old_checkpoint[path]
159 channels = old_tensor.shape[0] // 3
160
161 target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
162
163 num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
164
165 old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
166 query, key, value = old_tensor.split(channels // num_heads, dim=1)
167
168 checkpoint[path_map["query"]] = query.reshape(target_shape)
169 checkpoint[path_map["key"]] = key.reshape(target_shape)
170 checkpoint[path_map["value"]] = value.reshape(target_shape)
171
172 for path in paths:
173 new_path = path["new"]
174
175 # These have already been assigned
176 if attention_paths_to_split is not None and new_path in attention_paths_to_split:
177 continue
178
179 # Global renaming happens here
180 new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
181 new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
182 new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
183
184 if additional_replacements is not None:
185 for replacement in additional_replacements:
186 new_path = new_path.replace(replacement["old"], replacement["new"])
187
188 # proj_attn.weight has to be converted from conv 1D to linear
189 if "proj_attn.weight" in new_path:
190 checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
191 else:
192 checkpoint[new_path] = old_checkpoint[path["old"]]
193
194
195def conv_attn_to_linear(checkpoint):
196 keys = list(checkpoint.keys())
197 attn_keys = ["query.weight", "key.weight", "value.weight"]
198 for key in keys:
199 if ".".join(key.split(".")[-2:]) in attn_keys:
200 if checkpoint[key].ndim > 2:
201 checkpoint[key] = checkpoint[key][:, :, 0, 0]
202 elif "proj_attn.weight" in key:
203 if checkpoint[key].ndim > 2:
204 checkpoint[key] = checkpoint[key][:, :, 0]
205
206
207def create_unet_diffusers_config(original_config):
208 """
209 Creates a config for the diffusers based on the config of the LDM model.
210 """
211 unet_params = original_config.model.params.unet_config.params
212
213 block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
214
215 down_block_types = []
216 resolution = 1
217 for i in range(len(block_out_channels)):
218 block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
219 down_block_types.append(block_type)
220 if i != len(block_out_channels) - 1:
221 resolution *= 2
222
223 up_block_types = []
224 for i in range(len(block_out_channels)):
225 block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
226 up_block_types.append(block_type)
227 resolution //= 2
228
229 config = dict(
230 sample_size=unet_params.image_size,
231 in_channels=unet_params.in_channels,
232 out_channels=unet_params.out_channels,
233 down_block_types=tuple(down_block_types),
234 up_block_types=tuple(up_block_types),
235 block_out_channels=tuple(block_out_channels),
236 layers_per_block=unet_params.num_res_blocks,
237 cross_attention_dim=unet_params.context_dim,
238 attention_head_dim=unet_params.num_heads,
239 )
240
241 return config
242
243
244def create_vae_diffusers_config(original_config):
245 """
246 Creates a config for the diffusers based on the config of the LDM model.
247 """
248 vae_params = original_config.model.params.first_stage_config.params.ddconfig
249 _ = original_config.model.params.first_stage_config.params.embed_dim
250
251 block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
252 down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
253 up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
254
255 config = dict(
256 sample_size=vae_params.resolution,
257 in_channels=vae_params.in_channels,
258 out_channels=vae_params.out_ch,
259 down_block_types=tuple(down_block_types),
260 up_block_types=tuple(up_block_types),
261 block_out_channels=tuple(block_out_channels),
262 latent_channels=vae_params.z_channels,
263 layers_per_block=vae_params.num_res_blocks,
264 )
265 return config
266
267
268def create_diffusers_schedular(original_config):
269 schedular = DDIMScheduler(
270 num_train_timesteps=original_config.model.params.timesteps,
271 beta_start=original_config.model.params.linear_start,
272 beta_end=original_config.model.params.linear_end,
273 beta_schedule="scaled_linear",
274 )
275 return schedular
276
277
278def create_ldm_bert_config(original_config):
279 bert_params = original_config.model.parms.cond_stage_config.params
280 config = LDMBertConfig(
281 d_model=bert_params.n_embed,
282 encoder_layers=bert_params.n_layer,
283 encoder_ffn_dim=bert_params.n_embed * 4,
284 )
285 return config
286
287
288def convert_ldm_unet_checkpoint(checkpoint, config):
289 """
290 Takes a state dict and a config, and returns a converted checkpoint.
291 """
292
293 # extract state_dict for UNet
294 unet_state_dict = {}
295 unet_key = "model.diffusion_model."
296 keys = list(checkpoint.keys())
297 for key in keys:
298 if key.startswith(unet_key):
299 unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
300
301 new_checkpoint = {}
302
303 new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
304 new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
305 new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
306 new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
307
308 new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
309 new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
310
311 new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
312 new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
313 new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
314 new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
315
316 # Retrieves the keys for the input blocks only
317 num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
318 input_blocks = {
319 layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
320 for layer_id in range(num_input_blocks)
321 }
322
323 # Retrieves the keys for the middle blocks only
324 num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
325 middle_blocks = {
326 layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
327 for layer_id in range(num_middle_blocks)
328 }
329
330 # Retrieves the keys for the output blocks only
331 num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
332 output_blocks = {
333 layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
334 for layer_id in range(num_output_blocks)
335 }
336
337 for i in range(1, num_input_blocks):
338 block_id = (i - 1) // (config["layers_per_block"] + 1)
339 layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
340
341 resnets = [
342 key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
343 ]
344 attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
345
346 if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
347 new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
348 f"input_blocks.{i}.0.op.weight"
349 )
350 new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
351 f"input_blocks.{i}.0.op.bias"
352 )
353
354 paths = renew_resnet_paths(resnets)
355 meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
356 assign_to_checkpoint(
357 paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
358 )
359
360 if len(attentions):
361 paths = renew_attention_paths(attentions)
362 meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
363 assign_to_checkpoint(
364 paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
365 )
366
367 resnet_0 = middle_blocks[0]
368 attentions = middle_blocks[1]
369 resnet_1 = middle_blocks[2]
370
371 resnet_0_paths = renew_resnet_paths(resnet_0)
372 assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
373
374 resnet_1_paths = renew_resnet_paths(resnet_1)
375 assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
376
377 attentions_paths = renew_attention_paths(attentions)
378 meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
379 assign_to_checkpoint(
380 attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
381 )
382
383 for i in range(num_output_blocks):
384 block_id = i // (config["layers_per_block"] + 1)
385 layer_in_block_id = i % (config["layers_per_block"] + 1)
386 output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
387 output_block_list = {}
388
389 for layer in output_block_layers:
390 layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
391 if layer_id in output_block_list:
392 output_block_list[layer_id].append(layer_name)
393 else:
394 output_block_list[layer_id] = [layer_name]
395
396 if len(output_block_list) > 1:
397 resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
398 attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
399
400 resnet_0_paths = renew_resnet_paths(resnets)
401 paths = renew_resnet_paths(resnets)
402
403 meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
404 assign_to_checkpoint(
405 paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
406 )
407
408 if ["conv.weight", "conv.bias"] in output_block_list.values():
409 index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
410 new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
411 f"output_blocks.{i}.{index}.conv.weight"
412 ]
413 new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
414 f"output_blocks.{i}.{index}.conv.bias"
415 ]
416
417 # Clear attentions as they have been attributed above.
418 if len(attentions) == 2:
419 attentions = []
420
421 if len(attentions):
422 paths = renew_attention_paths(attentions)
423 meta_path = {
424 "old": f"output_blocks.{i}.1",
425 "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
426 }
427 assign_to_checkpoint(
428 paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
429 )
430 else:
431 resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
432 for path in resnet_0_paths:
433 old_path = ".".join(["output_blocks", str(i), path["old"]])
434 new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
435
436 new_checkpoint[new_path] = unet_state_dict[old_path]
437
438 return new_checkpoint
439
440
441def convert_ldm_vae_checkpoint(checkpoint, config):
442 # extract state dict for VAE
443 vae_state_dict = {}
444 vae_key = "first_stage_model."
445 keys = list(checkpoint.keys())
446 for key in keys:
447 if key.startswith(vae_key):
448 vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
449
450 new_checkpoint = {}
451
452 new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
453 new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
454 new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
455 new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
456 new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
457 new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
458
459 new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
460 new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
461 new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
462 new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
463 new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
464 new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
465
466 new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
467 new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
468 new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
469 new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
470
471 # Retrieves the keys for the encoder down blocks only
472 num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
473 down_blocks = {
474 layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
475 }
476
477 # Retrieves the keys for the decoder up blocks only
478 num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
479 up_blocks = {
480 layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
481 }
482
483 for i in range(num_down_blocks):
484 resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
485
486 if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
487 new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
488 f"encoder.down.{i}.downsample.conv.weight"
489 )
490 new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
491 f"encoder.down.{i}.downsample.conv.bias"
492 )
493
494 paths = renew_vae_resnet_paths(resnets)
495 meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
496 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
497
498 mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
499 num_mid_res_blocks = 2
500 for i in range(1, num_mid_res_blocks + 1):
501 resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
502
503 paths = renew_vae_resnet_paths(resnets)
504 meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
505 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
506
507 mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
508 paths = renew_vae_attention_paths(mid_attentions)
509 meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
510 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
511 conv_attn_to_linear(new_checkpoint)
512
513 for i in range(num_up_blocks):
514 block_id = num_up_blocks - 1 - i
515 resnets = [
516 key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
517 ]
518
519 if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
520 new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
521 f"decoder.up.{block_id}.upsample.conv.weight"
522 ]
523 new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
524 f"decoder.up.{block_id}.upsample.conv.bias"
525 ]
526
527 paths = renew_vae_resnet_paths(resnets)
528 meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
529 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
530
531 mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
532 num_mid_res_blocks = 2
533 for i in range(1, num_mid_res_blocks + 1):
534 resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
535
536 paths = renew_vae_resnet_paths(resnets)
537 meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
538 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
539
540 mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
541 paths = renew_vae_attention_paths(mid_attentions)
542 meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
543 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
544 conv_attn_to_linear(new_checkpoint)
545 return new_checkpoint
546
547
548def convert_ldm_bert_checkpoint(checkpoint, config):
549 def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
550 hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
551 hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
552 hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
553
554 hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
555 hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
556
557 def _copy_linear(hf_linear, pt_linear):
558 hf_linear.weight = pt_linear.weight
559 hf_linear.bias = pt_linear.bias
560
561 def _copy_layer(hf_layer, pt_layer):
562 # copy layer norms
563 _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
564 _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
565
566 # copy attn
567 _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
568
569 # copy MLP
570 pt_mlp = pt_layer[1][1]
571 _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
572 _copy_linear(hf_layer.fc2, pt_mlp.net[2])
573
574 def _copy_layers(hf_layers, pt_layers):
575 for i, hf_layer in enumerate(hf_layers):
576 if i != 0:
577 i += i
578 pt_layer = pt_layers[i : i + 2]
579 _copy_layer(hf_layer, pt_layer)
580
581 hf_model = LDMBertModel(config).eval()
582
583 # copy embeds
584 hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
585 hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
586
587 # copy layer norm
588 _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
589
590 # copy hidden layers
591 _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
592
593 _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
594
595 return hf_model
596
597
598if __name__ == "__main__":
599 parser = argparse.ArgumentParser()
600
601 parser.add_argument(
602 "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
603 )
604 # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
605 parser.add_argument(
606 "--original_config_file",
607 default=None,
608 type=str,
609 help="The YAML config file corresponding to the original architecture.",
610 )
611 parser.add_argument(
612 "--scheduler_type",
613 default="pndm",
614 type=str,
615 help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
616 )
617 parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
618
619 args = parser.parse_args()
620
621 if args.original_config_file is None:
622 os.system(
623 "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
624 )
625 args.original_config_file = "./v1-inference.yaml"
626
627 original_config = OmegaConf.load(args.original_config_file)
628 checkpoint = torch.load(args.checkpoint_path)["state_dict"]
629
630 num_train_timesteps = original_config.model.params.timesteps
631 beta_start = original_config.model.params.linear_start
632 beta_end = original_config.model.params.linear_end
633 if args.scheduler_type == "pndm":
634 scheduler = PNDMScheduler(
635 beta_end=beta_end,
636 beta_schedule="scaled_linear",
637 beta_start=beta_start,
638 num_train_timesteps=num_train_timesteps,
639 skip_prk_steps=True,
640 )
641 elif args.scheduler_type == "lms":
642 scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
643 elif args.scheduler_type == "ddim":
644 scheduler = DDIMScheduler(
645 beta_start=beta_start,
646 beta_end=beta_end,
647 beta_schedule="scaled_linear",
648 clip_sample=False,
649 set_alpha_to_one=False,
650 )
651 else:
652 raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
653
654 # Convert the UNet2DConditionModel model.
655 unet_config = create_unet_diffusers_config(original_config)
656 converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
657
658 unet = UNet2DConditionModel(**unet_config)
659 unet.load_state_dict(converted_unet_checkpoint)
660
661 # Convert the VAE model.
662 vae_config = create_vae_diffusers_config(original_config)
663 converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
664
665 vae = AutoencoderKL(**vae_config)
666 vae.load_state_dict(converted_vae_checkpoint)
667
668 # Convert the text model.
669 text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
670 if text_model_type == "FrozenCLIPEmbedder":
671 text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
672 tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
673 safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
674 feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
675 pipe = StableDiffusionPipeline(
676 vae=vae,
677 text_encoder=text_model,
678 tokenizer=tokenizer,
679 unet=unet,
680 scheduler=scheduler,
681 safety_checker=safety_checker,
682 feature_extractor=feature_extractor,
683 )
684 else:
685 text_config = create_ldm_bert_config(original_config)
686 text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
687 tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
688 pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
689
690 pipe.save_pretrained(args.dump_path)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1b8a3d2..7a33bca 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -110,7 +110,7 @@ def parse_args():
110 parser.add_argument( 110 parser.add_argument(
111 "--tag_dropout", 111 "--tag_dropout",
112 type=float, 112 type=float,
113 default=0.1, 113 default=0,
114 help="Tag dropout probability.", 114 help="Tag dropout probability.",
115 ) 115 )
116 parser.add_argument( 116 parser.add_argument(
@@ -131,6 +131,11 @@ def parse_args():
131 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', 131 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
132 ) 132 )
133 parser.add_argument( 133 parser.add_argument(
134 "--guidance_scale",
135 type=float,
136 default=0,
137 )
138 parser.add_argument(
134 "--num_class_images", 139 "--num_class_images",
135 type=int, 140 type=int,
136 default=0, 141 default=0,
@@ -178,7 +183,7 @@ def parse_args():
178 parser.add_argument( 183 parser.add_argument(
179 "--offset_noise_strength", 184 "--offset_noise_strength",
180 type=float, 185 type=float,
181 default=0.15, 186 default=0,
182 help="Perlin offset noise strength.", 187 help="Perlin offset noise strength.",
183 ) 188 )
184 parser.add_argument( 189 parser.add_argument(
@@ -557,8 +562,8 @@ def main():
557 vae=vae, 562 vae=vae,
558 noise_scheduler=noise_scheduler, 563 noise_scheduler=noise_scheduler,
559 dtype=weight_dtype, 564 dtype=weight_dtype,
560 with_prior_preservation=args.num_class_images != 0, 565 guidance_scale=args.guidance_scale,
561 prior_loss_weight=args.prior_loss_weight, 566 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
562 no_val=args.valid_set_size == 0, 567 no_val=args.valid_set_size == 0,
563 ) 568 )
564 569
@@ -570,6 +575,7 @@ def main():
570 batch_size=args.train_batch_size, 575 batch_size=args.train_batch_size,
571 tokenizer=tokenizer, 576 tokenizer=tokenizer,
572 class_subdir=args.class_image_dir, 577 class_subdir=args.class_image_dir,
578 with_guidance=args.guidance_scale != 0,
573 num_class_images=args.num_class_images, 579 num_class_images=args.num_class_images,
574 size=args.resolution, 580 size=args.resolution,
575 num_buckets=args.num_buckets, 581 num_buckets=args.num_buckets,
diff --git a/train_lora.py b/train_lora.py
index b16a99b..684d0cc 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -88,7 +88,7 @@ def parse_args():
88 parser.add_argument( 88 parser.add_argument(
89 "--num_buckets", 89 "--num_buckets",
90 type=int, 90 type=int,
91 default=0, 91 default=2,
92 help="Number of aspect ratio buckets in either direction.", 92 help="Number of aspect ratio buckets in either direction.",
93 ) 93 )
94 parser.add_argument( 94 parser.add_argument(
@@ -111,7 +111,7 @@ def parse_args():
111 parser.add_argument( 111 parser.add_argument(
112 "--tag_dropout", 112 "--tag_dropout",
113 type=float, 113 type=float,
114 default=0.1, 114 default=0,
115 help="Tag dropout probability.", 115 help="Tag dropout probability.",
116 ) 116 )
117 parser.add_argument( 117 parser.add_argument(
@@ -120,6 +120,11 @@ def parse_args():
120 help="Shuffle tags.", 120 help="Shuffle tags.",
121 ) 121 )
122 parser.add_argument( 122 parser.add_argument(
123 "--guidance_scale",
124 type=float,
125 default=0,
126 )
127 parser.add_argument(
123 "--num_class_images", 128 "--num_class_images",
124 type=int, 129 type=int,
125 default=0, 130 default=0,
@@ -167,7 +172,7 @@ def parse_args():
167 parser.add_argument( 172 parser.add_argument(
168 "--offset_noise_strength", 173 "--offset_noise_strength",
169 type=float, 174 type=float,
170 default=0.15, 175 default=0,
171 help="Perlin offset noise strength.", 176 help="Perlin offset noise strength.",
172 ) 177 )
173 parser.add_argument( 178 parser.add_argument(
@@ -589,8 +594,8 @@ def main():
589 vae=vae, 594 vae=vae,
590 noise_scheduler=noise_scheduler, 595 noise_scheduler=noise_scheduler,
591 dtype=weight_dtype, 596 dtype=weight_dtype,
592 with_prior_preservation=args.num_class_images != 0, 597 guidance_scale=args.guidance_scale,
593 prior_loss_weight=args.prior_loss_weight, 598 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
594 no_val=args.valid_set_size == 0, 599 no_val=args.valid_set_size == 0,
595 ) 600 )
596 601
@@ -602,6 +607,7 @@ def main():
602 batch_size=args.train_batch_size, 607 batch_size=args.train_batch_size,
603 tokenizer=tokenizer, 608 tokenizer=tokenizer,
604 class_subdir=args.class_image_dir, 609 class_subdir=args.class_image_dir,
610 with_guidance=args.guidance_scale != 0,
605 num_class_images=args.num_class_images, 611 num_class_images=args.num_class_images,
606 size=args.resolution, 612 size=args.resolution,
607 num_buckets=args.num_buckets, 613 num_buckets=args.num_buckets,
diff --git a/train_ti.py b/train_ti.py
index bbc5524..83ad46d 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -91,6 +91,11 @@ def parse_args():
91 action="store_true", 91 action="store_true",
92 ) 92 )
93 parser.add_argument( 93 parser.add_argument(
94 "--guidance_scale",
95 type=float,
96 default=0,
97 )
98 parser.add_argument(
94 "--num_class_images", 99 "--num_class_images",
95 type=int, 100 type=int,
96 default=0, 101 default=0,
@@ -167,7 +172,7 @@ def parse_args():
167 parser.add_argument( 172 parser.add_argument(
168 "--tag_dropout", 173 "--tag_dropout",
169 type=float, 174 type=float,
170 default=0.1, 175 default=0,
171 help="Tag dropout probability.", 176 help="Tag dropout probability.",
172 ) 177 )
173 parser.add_argument( 178 parser.add_argument(
@@ -190,7 +195,7 @@ def parse_args():
190 parser.add_argument( 195 parser.add_argument(
191 "--offset_noise_strength", 196 "--offset_noise_strength",
192 type=float, 197 type=float,
193 default=0.15, 198 default=0,
194 help="Perlin offset noise strength.", 199 help="Perlin offset noise strength.",
195 ) 200 )
196 parser.add_argument( 201 parser.add_argument(
@@ -651,8 +656,8 @@ def main():
651 noise_scheduler=noise_scheduler, 656 noise_scheduler=noise_scheduler,
652 dtype=weight_dtype, 657 dtype=weight_dtype,
653 seed=args.seed, 658 seed=args.seed,
654 with_prior_preservation=args.num_class_images != 0, 659 guidance_scale=args.guidance_scale,
655 prior_loss_weight=args.prior_loss_weight, 660 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
656 no_val=args.valid_set_size == 0, 661 no_val=args.valid_set_size == 0,
657 strategy=textual_inversion_strategy, 662 strategy=textual_inversion_strategy,
658 num_train_epochs=args.num_train_epochs, 663 num_train_epochs=args.num_train_epochs,
@@ -705,6 +710,7 @@ def main():
705 batch_size=args.train_batch_size, 710 batch_size=args.train_batch_size,
706 tokenizer=tokenizer, 711 tokenizer=tokenizer,
707 class_subdir=args.class_image_dir, 712 class_subdir=args.class_image_dir,
713 with_guidance=args.guidance_scale != 0,
708 num_class_images=args.num_class_images, 714 num_class_images=args.num_class_images,
709 size=args.resolution, 715 size=args.resolution,
710 num_buckets=args.num_buckets, 716 num_buckets=args.num_buckets,
diff --git a/training/functional.py b/training/functional.py
index 87bb339..d285366 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -274,7 +274,7 @@ def loss_step(
274 noise_scheduler: SchedulerMixin, 274 noise_scheduler: SchedulerMixin,
275 unet: UNet2DConditionModel, 275 unet: UNet2DConditionModel,
276 text_encoder: CLIPTextModel, 276 text_encoder: CLIPTextModel,
277 with_prior_preservation: bool, 277 guidance_scale: float,
278 prior_loss_weight: float, 278 prior_loss_weight: float,
279 seed: int, 279 seed: int,
280 offset_noise_strength: float, 280 offset_noise_strength: float,
@@ -283,13 +283,13 @@ def loss_step(
283 eval: bool = False, 283 eval: bool = False,
284 min_snr_gamma: int = 5, 284 min_snr_gamma: int = 5,
285): 285):
286 # Convert images to latent space 286 images = batch["pixel_values"]
287 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 287 generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
288 latents = latents * vae.config.scaling_factor 288 bsz = images.shape[0]
289
290 bsz = latents.shape[0]
291 289
292 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None 290 # Convert images to latent space
291 latents = vae.encode(images).latent_dist.sample(generator=generator)
292 latents *= vae.config.scaling_factor
293 293
294 # Sample noise that we'll add to the latents 294 # Sample noise that we'll add to the latents
295 noise = torch.randn( 295 noise = torch.randn(
@@ -301,13 +301,13 @@ def loss_step(
301 ) 301 )
302 302
303 if offset_noise_strength != 0: 303 if offset_noise_strength != 0:
304 noise += offset_noise_strength * perlin_noise( 304 offset_noise = torch.randn(
305 latents.shape, 305 (latents.shape[0], latents.shape[1], 1, 1),
306 res=1,
307 dtype=latents.dtype, 306 dtype=latents.dtype,
308 device=latents.device, 307 device=latents.device,
309 generator=generator 308 generator=generator
310 ) 309 ).expand(noise.shape)
310 noise += offset_noise_strength * offset_noise
311 311
312 # Sample a random timestep for each image 312 # Sample a random timestep for each image
313 timesteps = torch.randint( 313 timesteps = torch.randint(
@@ -343,7 +343,13 @@ def loss_step(
343 else: 343 else:
344 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 344 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
345 345
346 if with_prior_preservation: 346 if guidance_scale != 0:
347 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
348 model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0)
349 model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond)
350
351 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
352 elif prior_loss_weight != 0:
347 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 353 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
348 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 354 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
349 target, target_prior = torch.chunk(target, 2, dim=0) 355 target, target_prior = torch.chunk(target, 2, dim=0)
@@ -607,9 +613,9 @@ def train(
607 checkpoint_frequency: int = 50, 613 checkpoint_frequency: int = 50,
608 milestone_checkpoints: bool = True, 614 milestone_checkpoints: bool = True,
609 global_step_offset: int = 0, 615 global_step_offset: int = 0,
610 with_prior_preservation: bool = False, 616 guidance_scale: float = 0.0,
611 prior_loss_weight: float = 1.0, 617 prior_loss_weight: float = 1.0,
612 offset_noise_strength: float = 0.1, 618 offset_noise_strength: float = 0.15,
613 **kwargs, 619 **kwargs,
614): 620):
615 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 621 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -638,7 +644,7 @@ def train(
638 noise_scheduler, 644 noise_scheduler,
639 unet, 645 unet,
640 text_encoder, 646 text_encoder,
641 with_prior_preservation, 647 guidance_scale,
642 prior_loss_weight, 648 prior_loss_weight,
643 seed, 649 seed,
644 offset_noise_strength, 650 offset_noise_strength,