-rw-r--r--   scripts/convert_diffusers_to_original_stable_diffusion.py | 234
1 file changed, 234 insertions, 0 deletions

diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py
new file mode 100644
index 0000000..9888f62
--- /dev/null
+++ b/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -0,0 +1,234 @@
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or anything else.
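#
# A typical invocation might look like the following (the paths are illustrative,
# not part of this diff); the flags match the argparse options defined below:
#
#   python scripts/convert_diffusers_to_original_stable_diffusion.py \
#       --model_path path/to/diffusers_pipeline \
#       --checkpoint_path path/to/output.ckpt \
#       --half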

import argparse
import os.path as osp

import torch


# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
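# For reference, the loops below append (SD, HF) prefix pairs such as
#   ("input_blocks.1.0.", "down_blocks.0.resnets.0.") and
#   ("input_blocks.1.1.", "down_blocks.0.attentions.0.").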
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
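    # As an illustration: a key like "down_blocks.0.resnets.0.norm1.weight" first
    # has its resnet parts renamed ("norm1" -> "in_layers.0") and then its block
    # prefix renamed ("down_blocks.0.resnets.0." -> "input_blocks.1.0."),
    # ending up as "input_blocks.1.0.in_layers.0.weight".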
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
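    # e.g. a [C, C] nn.Linear weight becomes a [C, C, 1, 1] kernel for a 1x1 Conv2d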
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

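    # The key prefixes added below place each component where the original
    # Stable Diffusion (CompVis LDM) checkpoint expects it: the UNet under
    # "model.diffusion_model.", the VAE under "first_stage_model.", and the
    # text encoder under "cond_stage_model.transformer.".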
    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location="cpu")
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location="cpu")
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location="cpu")
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)
