diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
| commit | 6b8a93f46f053668c8023520225a18445d48d8f1 (patch) | |
| tree | 463c8835a9a90dd9b5586a13e55d6882caa3103a /scripts | |
| parent | Update (diff) | |
| download | textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.gz textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.bz2 textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.zip | |
Update
Diffstat (limited to 'scripts')
| -rw-r--r-- | scripts/convert_diffusers_to_original_stable_diffusion.py | 234 | ||||
| -rw-r--r-- | scripts/convert_original_stable_diffusion_to_diffusers.py | 690 |
2 files changed, 0 insertions, 924 deletions
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py deleted file mode 100644 index 9888f62..0000000 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ /dev/null | |||
| @@ -1,234 +0,0 @@ | |||
| 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. | ||
| 2 | # *Only* converts the UNet, VAE, and Text Encoder. | ||
| 3 | # Does not convert optimizer state or any other thing. | ||
| 4 | |||
| 5 | import argparse | ||
| 6 | import os.path as osp | ||
| 7 | |||
| 8 | import torch | ||
| 9 | |||
| 10 | |||
| 11 | # =================# | ||
| 12 | # UNet Conversion # | ||
| 13 | # =================# | ||
| 14 | |||
| 15 | unet_conversion_map = [ | ||
| 16 | # (stable-diffusion, HF Diffusers) | ||
| 17 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), | ||
| 18 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), | ||
| 19 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), | ||
| 20 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), | ||
| 21 | ("input_blocks.0.0.weight", "conv_in.weight"), | ||
| 22 | ("input_blocks.0.0.bias", "conv_in.bias"), | ||
| 23 | ("out.0.weight", "conv_norm_out.weight"), | ||
| 24 | ("out.0.bias", "conv_norm_out.bias"), | ||
| 25 | ("out.2.weight", "conv_out.weight"), | ||
| 26 | ("out.2.bias", "conv_out.bias"), | ||
| 27 | ] | ||
| 28 | |||
| 29 | unet_conversion_map_resnet = [ | ||
| 30 | # (stable-diffusion, HF Diffusers) | ||
| 31 | ("in_layers.0", "norm1"), | ||
| 32 | ("in_layers.2", "conv1"), | ||
| 33 | ("out_layers.0", "norm2"), | ||
| 34 | ("out_layers.3", "conv2"), | ||
| 35 | ("emb_layers.1", "time_emb_proj"), | ||
| 36 | ("skip_connection", "conv_shortcut"), | ||
| 37 | ] | ||
| 38 | |||
| 39 | unet_conversion_map_layer = [] | ||
| 40 | # hardcoded number of downblocks and resnets/attentions... | ||
| 41 | # would need smarter logic for other networks. | ||
| 42 | for i in range(4): | ||
| 43 | # loop over downblocks/upblocks | ||
| 44 | |||
| 45 | for j in range(2): | ||
| 46 | # loop over resnets/attentions for downblocks | ||
| 47 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." | ||
| 48 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." | ||
| 49 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) | ||
| 50 | |||
| 51 | if i < 3: | ||
| 52 | # no attention layers in down_blocks.3 | ||
| 53 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." | ||
| 54 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." | ||
| 55 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) | ||
| 56 | |||
| 57 | for j in range(3): | ||
| 58 | # loop over resnets/attentions for upblocks | ||
| 59 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." | ||
| 60 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." | ||
| 61 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) | ||
| 62 | |||
| 63 | if i > 0: | ||
| 64 | # no attention layers in up_blocks.0 | ||
| 65 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." | ||
| 66 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." | ||
| 67 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) | ||
| 68 | |||
| 69 | if i < 3: | ||
| 70 | # no downsample in down_blocks.3 | ||
| 71 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." | ||
| 72 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." | ||
| 73 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) | ||
| 74 | |||
| 75 | # no upsample in up_blocks.3 | ||
| 76 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." | ||
| 77 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." | ||
| 78 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) | ||
| 79 | |||
| 80 | hf_mid_atn_prefix = "mid_block.attentions.0." | ||
| 81 | sd_mid_atn_prefix = "middle_block.1." | ||
| 82 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) | ||
| 83 | |||
| 84 | for j in range(2): | ||
| 85 | hf_mid_res_prefix = f"mid_block.resnets.{j}." | ||
| 86 | sd_mid_res_prefix = f"middle_block.{2*j}." | ||
| 87 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) | ||
| 88 | |||
| 89 | |||
| 90 | def convert_unet_state_dict(unet_state_dict): | ||
| 91 | # buyer beware: this is a *brittle* function, | ||
| 92 | # and correct output requires that all of these pieces interact in | ||
| 93 | # the exact order in which I have arranged them. | ||
| 94 | mapping = {k: k for k in unet_state_dict.keys()} | ||
| 95 | for sd_name, hf_name in unet_conversion_map: | ||
| 96 | mapping[hf_name] = sd_name | ||
| 97 | for k, v in mapping.items(): | ||
| 98 | if "resnets" in k: | ||
| 99 | for sd_part, hf_part in unet_conversion_map_resnet: | ||
| 100 | v = v.replace(hf_part, sd_part) | ||
| 101 | mapping[k] = v | ||
| 102 | for k, v in mapping.items(): | ||
| 103 | for sd_part, hf_part in unet_conversion_map_layer: | ||
| 104 | v = v.replace(hf_part, sd_part) | ||
| 105 | mapping[k] = v | ||
| 106 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} | ||
| 107 | return new_state_dict | ||
| 108 | |||
| 109 | |||
| 110 | # ================# | ||
| 111 | # VAE Conversion # | ||
| 112 | # ================# | ||
| 113 | |||
| 114 | vae_conversion_map = [ | ||
| 115 | # (stable-diffusion, HF Diffusers) | ||
| 116 | ("nin_shortcut", "conv_shortcut"), | ||
| 117 | ("norm_out", "conv_norm_out"), | ||
| 118 | ("mid.attn_1.", "mid_block.attentions.0."), | ||
| 119 | ] | ||
| 120 | |||
| 121 | for i in range(4): | ||
| 122 | # down_blocks have two resnets | ||
| 123 | for j in range(2): | ||
| 124 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." | ||
| 125 | sd_down_prefix = f"encoder.down.{i}.block.{j}." | ||
| 126 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) | ||
| 127 | |||
| 128 | if i < 3: | ||
| 129 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." | ||
| 130 | sd_downsample_prefix = f"down.{i}.downsample." | ||
| 131 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) | ||
| 132 | |||
| 133 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." | ||
| 134 | sd_upsample_prefix = f"up.{3-i}.upsample." | ||
| 135 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) | ||
| 136 | |||
| 137 | # up_blocks have three resnets | ||
| 138 | # also, up blocks in hf are numbered in reverse from sd | ||
| 139 | for j in range(3): | ||
| 140 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." | ||
| 141 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." | ||
| 142 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) | ||
| 143 | |||
| 144 | # this part accounts for mid blocks in both the encoder and the decoder | ||
| 145 | for i in range(2): | ||
| 146 | hf_mid_res_prefix = f"mid_block.resnets.{i}." | ||
| 147 | sd_mid_res_prefix = f"mid.block_{i+1}." | ||
| 148 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) | ||
| 149 | |||
| 150 | |||
| 151 | vae_conversion_map_attn = [ | ||
| 152 | # (stable-diffusion, HF Diffusers) | ||
| 153 | ("norm.", "group_norm."), | ||
| 154 | ("q.", "query."), | ||
| 155 | ("k.", "key."), | ||
| 156 | ("v.", "value."), | ||
| 157 | ("proj_out.", "proj_attn."), | ||
| 158 | ] | ||
| 159 | |||
| 160 | |||
| 161 | def reshape_weight_for_sd(w): | ||
| 162 | # convert HF linear weights to SD conv2d weights | ||
| 163 | return w.reshape(*w.shape, 1, 1) | ||
| 164 | |||
| 165 | |||
| 166 | def convert_vae_state_dict(vae_state_dict): | ||
| 167 | mapping = {k: k for k in vae_state_dict.keys()} | ||
| 168 | for k, v in mapping.items(): | ||
| 169 | for sd_part, hf_part in vae_conversion_map: | ||
| 170 | v = v.replace(hf_part, sd_part) | ||
| 171 | mapping[k] = v | ||
| 172 | for k, v in mapping.items(): | ||
| 173 | if "attentions" in k: | ||
| 174 | for sd_part, hf_part in vae_conversion_map_attn: | ||
| 175 | v = v.replace(hf_part, sd_part) | ||
| 176 | mapping[k] = v | ||
| 177 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} | ||
| 178 | weights_to_convert = ["q", "k", "v", "proj_out"] | ||
| 179 | for k, v in new_state_dict.items(): | ||
| 180 | for weight_name in weights_to_convert: | ||
| 181 | if f"mid.attn_1.{weight_name}.weight" in k: | ||
| 182 | print(f"Reshaping {k} for SD format") | ||
| 183 | new_state_dict[k] = reshape_weight_for_sd(v) | ||
| 184 | return new_state_dict | ||
| 185 | |||
| 186 | |||
| 187 | # =========================# | ||
| 188 | # Text Encoder Conversion # | ||
| 189 | # =========================# | ||
| 190 | # pretty much a no-op | ||
| 191 | |||
| 192 | |||
| 193 | def convert_text_enc_state_dict(text_enc_dict): | ||
| 194 | return text_enc_dict | ||
| 195 | |||
| 196 | |||
| 197 | if __name__ == "__main__": | ||
| 198 | parser = argparse.ArgumentParser() | ||
| 199 | |||
| 200 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") | ||
| 201 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") | ||
| 202 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") | ||
| 203 | |||
| 204 | args = parser.parse_args() | ||
| 205 | |||
| 206 | assert args.model_path is not None, "Must provide a model path!" | ||
| 207 | |||
| 208 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" | ||
| 209 | |||
| 210 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") | ||
| 211 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") | ||
| 212 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") | ||
| 213 | |||
| 214 | # Convert the UNet model | ||
| 215 | unet_state_dict = torch.load(unet_path, map_location="cpu") | ||
| 216 | unet_state_dict = convert_unet_state_dict(unet_state_dict) | ||
| 217 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} | ||
| 218 | |||
| 219 | # Convert the VAE model | ||
| 220 | vae_state_dict = torch.load(vae_path, map_location="cpu") | ||
| 221 | vae_state_dict = convert_vae_state_dict(vae_state_dict) | ||
| 222 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} | ||
| 223 | |||
| 224 | # Convert the text encoder model | ||
| 225 | text_enc_dict = torch.load(text_enc_path, map_location="cpu") | ||
| 226 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) | ||
| 227 | text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} | ||
| 228 | |||
| 229 | # Put together new checkpoint | ||
| 230 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} | ||
| 231 | if args.half: | ||
| 232 | state_dict = {k: v.half() for k, v in state_dict.items()} | ||
| 233 | state_dict = {"state_dict": state_dict} | ||
| 234 | torch.save(state_dict, args.checkpoint_path) | ||
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py deleted file mode 100644 index ee7fc33..0000000 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ /dev/null | |||
| @@ -1,690 +0,0 @@ | |||
| 1 | # coding=utf-8 | ||
| 2 | # Copyright 2022 The HuggingFace Inc. team. | ||
| 3 | # | ||
| 4 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 5 | # you may not use this file except in compliance with the License. | ||
| 6 | # You may obtain a copy of the License at | ||
| 7 | # | ||
| 8 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
| 9 | # | ||
| 10 | # Unless required by applicable law or agreed to in writing, software | ||
| 11 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
| 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 13 | # See the License for the specific language governing permissions and | ||
| 14 | # limitations under the License. | ||
| 15 | """ Conversion script for the LDM checkpoints. """ | ||
| 16 | |||
| 17 | import argparse | ||
| 18 | import os | ||
| 19 | |||
| 20 | import torch | ||
| 21 | |||
| 22 | |||
| 23 | try: | ||
| 24 | from omegaconf import OmegaConf | ||
| 25 | except ImportError: | ||
| 26 | raise ImportError( | ||
| 27 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." | ||
| 28 | ) | ||
| 29 | |||
| 30 | from diffusers import ( | ||
| 31 | AutoencoderKL, | ||
| 32 | DDIMScheduler, | ||
| 33 | LDMTextToImagePipeline, | ||
| 34 | LMSDiscreteScheduler, | ||
| 35 | PNDMScheduler, | ||
| 36 | StableDiffusionPipeline, | ||
| 37 | UNet2DConditionModel, | ||
| 38 | ) | ||
| 39 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel | ||
| 40 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
| 41 | from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer | ||
| 42 | |||
| 43 | |||
| 44 | def shave_segments(path, n_shave_prefix_segments=1): | ||
| 45 | """ | ||
| 46 | Removes segments. Positive values shave the first segments, negative shave the last segments. | ||
| 47 | """ | ||
| 48 | if n_shave_prefix_segments >= 0: | ||
| 49 | return ".".join(path.split(".")[n_shave_prefix_segments:]) | ||
| 50 | else: | ||
| 51 | return ".".join(path.split(".")[:n_shave_prefix_segments]) | ||
| 52 | |||
| 53 | |||
| 54 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
| 55 | """ | ||
| 56 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
| 57 | """ | ||
| 58 | mapping = [] | ||
| 59 | for old_item in old_list: | ||
| 60 | new_item = old_item.replace("in_layers.0", "norm1") | ||
| 61 | new_item = new_item.replace("in_layers.2", "conv1") | ||
| 62 | |||
| 63 | new_item = new_item.replace("out_layers.0", "norm2") | ||
| 64 | new_item = new_item.replace("out_layers.3", "conv2") | ||
| 65 | |||
| 66 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") | ||
| 67 | new_item = new_item.replace("skip_connection", "conv_shortcut") | ||
| 68 | |||
| 69 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 70 | |||
| 71 | mapping.append({"old": old_item, "new": new_item}) | ||
| 72 | |||
| 73 | return mapping | ||
| 74 | |||
| 75 | |||
| 76 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
| 77 | """ | ||
| 78 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
| 79 | """ | ||
| 80 | mapping = [] | ||
| 81 | for old_item in old_list: | ||
| 82 | new_item = old_item | ||
| 83 | |||
| 84 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") | ||
| 85 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 86 | |||
| 87 | mapping.append({"old": old_item, "new": new_item}) | ||
| 88 | |||
| 89 | return mapping | ||
| 90 | |||
| 91 | |||
| 92 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): | ||
| 93 | """ | ||
| 94 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
| 95 | """ | ||
| 96 | mapping = [] | ||
| 97 | for old_item in old_list: | ||
| 98 | new_item = old_item | ||
| 99 | |||
| 100 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') | ||
| 101 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') | ||
| 102 | |||
| 103 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') | ||
| 104 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') | ||
| 105 | |||
| 106 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 107 | |||
| 108 | mapping.append({"old": old_item, "new": new_item}) | ||
| 109 | |||
| 110 | return mapping | ||
| 111 | |||
| 112 | |||
| 113 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): | ||
| 114 | """ | ||
| 115 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
| 116 | """ | ||
| 117 | mapping = [] | ||
| 118 | for old_item in old_list: | ||
| 119 | new_item = old_item | ||
| 120 | |||
| 121 | new_item = new_item.replace("norm.weight", "group_norm.weight") | ||
| 122 | new_item = new_item.replace("norm.bias", "group_norm.bias") | ||
| 123 | |||
| 124 | new_item = new_item.replace("q.weight", "query.weight") | ||
| 125 | new_item = new_item.replace("q.bias", "query.bias") | ||
| 126 | |||
| 127 | new_item = new_item.replace("k.weight", "key.weight") | ||
| 128 | new_item = new_item.replace("k.bias", "key.bias") | ||
| 129 | |||
| 130 | new_item = new_item.replace("v.weight", "value.weight") | ||
| 131 | new_item = new_item.replace("v.bias", "value.bias") | ||
| 132 | |||
| 133 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") | ||
| 134 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") | ||
| 135 | |||
| 136 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 137 | |||
| 138 | mapping.append({"old": old_item, "new": new_item}) | ||
| 139 | |||
| 140 | return mapping | ||
| 141 | |||
| 142 | |||
| 143 | def assign_to_checkpoint( | ||
| 144 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None | ||
| 145 | ): | ||
| 146 | """ | ||
| 147 | This does the final conversion step: take locally converted weights and apply a global renaming | ||
| 148 | to them. It splits attention layers, and takes into account additional replacements | ||
| 149 | that may arise. | ||
| 150 | |||
| 151 | Assigns the weights to the new checkpoint. | ||
| 152 | """ | ||
| 153 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." | ||
| 154 | |||
| 155 | # Splits the attention layers into three variables. | ||
| 156 | if attention_paths_to_split is not None: | ||
| 157 | for path, path_map in attention_paths_to_split.items(): | ||
| 158 | old_tensor = old_checkpoint[path] | ||
| 159 | channels = old_tensor.shape[0] // 3 | ||
| 160 | |||
| 161 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) | ||
| 162 | |||
| 163 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 | ||
| 164 | |||
| 165 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) | ||
| 166 | query, key, value = old_tensor.split(channels // num_heads, dim=1) | ||
| 167 | |||
| 168 | checkpoint[path_map["query"]] = query.reshape(target_shape) | ||
| 169 | checkpoint[path_map["key"]] = key.reshape(target_shape) | ||
| 170 | checkpoint[path_map["value"]] = value.reshape(target_shape) | ||
| 171 | |||
| 172 | for path in paths: | ||
| 173 | new_path = path["new"] | ||
| 174 | |||
| 175 | # These have already been assigned | ||
| 176 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: | ||
| 177 | continue | ||
| 178 | |||
| 179 | # Global renaming happens here | ||
| 180 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") | ||
| 181 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") | ||
| 182 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") | ||
| 183 | |||
| 184 | if additional_replacements is not None: | ||
| 185 | for replacement in additional_replacements: | ||
| 186 | new_path = new_path.replace(replacement["old"], replacement["new"]) | ||
| 187 | |||
| 188 | # proj_attn.weight has to be converted from conv 1D to linear | ||
| 189 | if "proj_attn.weight" in new_path: | ||
| 190 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] | ||
| 191 | else: | ||
| 192 | checkpoint[new_path] = old_checkpoint[path["old"]] | ||
| 193 | |||
| 194 | |||
| 195 | def conv_attn_to_linear(checkpoint): | ||
| 196 | keys = list(checkpoint.keys()) | ||
| 197 | attn_keys = ["query.weight", "key.weight", "value.weight"] | ||
| 198 | for key in keys: | ||
| 199 | if ".".join(key.split(".")[-2:]) in attn_keys: | ||
| 200 | if checkpoint[key].ndim > 2: | ||
| 201 | checkpoint[key] = checkpoint[key][:, :, 0, 0] | ||
| 202 | elif "proj_attn.weight" in key: | ||
| 203 | if checkpoint[key].ndim > 2: | ||
| 204 | checkpoint[key] = checkpoint[key][:, :, 0] | ||
| 205 | |||
| 206 | |||
| 207 | def create_unet_diffusers_config(original_config): | ||
| 208 | """ | ||
| 209 | Creates a config for the diffusers based on the config of the LDM model. | ||
| 210 | """ | ||
| 211 | unet_params = original_config.model.params.unet_config.params | ||
| 212 | |||
| 213 | block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] | ||
| 214 | |||
| 215 | down_block_types = [] | ||
| 216 | resolution = 1 | ||
| 217 | for i in range(len(block_out_channels)): | ||
| 218 | block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" | ||
| 219 | down_block_types.append(block_type) | ||
| 220 | if i != len(block_out_channels) - 1: | ||
| 221 | resolution *= 2 | ||
| 222 | |||
| 223 | up_block_types = [] | ||
| 224 | for i in range(len(block_out_channels)): | ||
| 225 | block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" | ||
| 226 | up_block_types.append(block_type) | ||
| 227 | resolution //= 2 | ||
| 228 | |||
| 229 | config = dict( | ||
| 230 | sample_size=unet_params.image_size, | ||
| 231 | in_channels=unet_params.in_channels, | ||
| 232 | out_channels=unet_params.out_channels, | ||
| 233 | down_block_types=tuple(down_block_types), | ||
| 234 | up_block_types=tuple(up_block_types), | ||
| 235 | block_out_channels=tuple(block_out_channels), | ||
| 236 | layers_per_block=unet_params.num_res_blocks, | ||
| 237 | cross_attention_dim=unet_params.context_dim, | ||
| 238 | attention_head_dim=unet_params.num_heads, | ||
| 239 | ) | ||
| 240 | |||
| 241 | return config | ||
| 242 | |||
| 243 | |||
| 244 | def create_vae_diffusers_config(original_config): | ||
| 245 | """ | ||
| 246 | Creates a config for the diffusers based on the config of the LDM model. | ||
| 247 | """ | ||
| 248 | vae_params = original_config.model.params.first_stage_config.params.ddconfig | ||
| 249 | _ = original_config.model.params.first_stage_config.params.embed_dim | ||
| 250 | |||
| 251 | block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] | ||
| 252 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) | ||
| 253 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) | ||
| 254 | |||
| 255 | config = dict( | ||
| 256 | sample_size=vae_params.resolution, | ||
| 257 | in_channels=vae_params.in_channels, | ||
| 258 | out_channels=vae_params.out_ch, | ||
| 259 | down_block_types=tuple(down_block_types), | ||
| 260 | up_block_types=tuple(up_block_types), | ||
| 261 | block_out_channels=tuple(block_out_channels), | ||
| 262 | latent_channels=vae_params.z_channels, | ||
| 263 | layers_per_block=vae_params.num_res_blocks, | ||
| 264 | ) | ||
| 265 | return config | ||
| 266 | |||
| 267 | |||
| 268 | def create_diffusers_schedular(original_config): | ||
| 269 | schedular = DDIMScheduler( | ||
| 270 | num_train_timesteps=original_config.model.params.timesteps, | ||
| 271 | beta_start=original_config.model.params.linear_start, | ||
| 272 | beta_end=original_config.model.params.linear_end, | ||
| 273 | beta_schedule="scaled_linear", | ||
| 274 | ) | ||
| 275 | return schedular | ||
| 276 | |||
| 277 | |||
| 278 | def create_ldm_bert_config(original_config): | ||
| 279 | bert_params = original_config.model.parms.cond_stage_config.params | ||
| 280 | config = LDMBertConfig( | ||
| 281 | d_model=bert_params.n_embed, | ||
| 282 | encoder_layers=bert_params.n_layer, | ||
| 283 | encoder_ffn_dim=bert_params.n_embed * 4, | ||
| 284 | ) | ||
| 285 | return config | ||
| 286 | |||
| 287 | |||
| 288 | def convert_ldm_unet_checkpoint(checkpoint, config): | ||
| 289 | """ | ||
| 290 | Takes a state dict and a config, and returns a converted checkpoint. | ||
| 291 | """ | ||
| 292 | |||
| 293 | # extract state_dict for UNet | ||
| 294 | unet_state_dict = {} | ||
| 295 | unet_key = "model.diffusion_model." | ||
| 296 | keys = list(checkpoint.keys()) | ||
| 297 | for key in keys: | ||
| 298 | if key.startswith(unet_key): | ||
| 299 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) | ||
| 300 | |||
| 301 | new_checkpoint = {} | ||
| 302 | |||
| 303 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] | ||
| 304 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] | ||
| 305 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] | ||
| 306 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] | ||
| 307 | |||
| 308 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] | ||
| 309 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] | ||
| 310 | |||
| 311 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] | ||
| 312 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] | ||
| 313 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] | ||
| 314 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] | ||
| 315 | |||
| 316 | # Retrieves the keys for the input blocks only | ||
| 317 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) | ||
| 318 | input_blocks = { | ||
| 319 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] | ||
| 320 | for layer_id in range(num_input_blocks) | ||
| 321 | } | ||
| 322 | |||
| 323 | # Retrieves the keys for the middle blocks only | ||
| 324 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) | ||
| 325 | middle_blocks = { | ||
| 326 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] | ||
| 327 | for layer_id in range(num_middle_blocks) | ||
| 328 | } | ||
| 329 | |||
| 330 | # Retrieves the keys for the output blocks only | ||
| 331 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) | ||
| 332 | output_blocks = { | ||
| 333 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] | ||
| 334 | for layer_id in range(num_output_blocks) | ||
| 335 | } | ||
| 336 | |||
| 337 | for i in range(1, num_input_blocks): | ||
| 338 | block_id = (i - 1) // (config["layers_per_block"] + 1) | ||
| 339 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) | ||
| 340 | |||
| 341 | resnets = [ | ||
| 342 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key | ||
| 343 | ] | ||
| 344 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] | ||
| 345 | |||
| 346 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: | ||
| 347 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( | ||
| 348 | f"input_blocks.{i}.0.op.weight" | ||
| 349 | ) | ||
| 350 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( | ||
| 351 | f"input_blocks.{i}.0.op.bias" | ||
| 352 | ) | ||
| 353 | |||
| 354 | paths = renew_resnet_paths(resnets) | ||
| 355 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
| 356 | assign_to_checkpoint( | ||
| 357 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 358 | ) | ||
| 359 | |||
| 360 | if len(attentions): | ||
| 361 | paths = renew_attention_paths(attentions) | ||
| 362 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} | ||
| 363 | assign_to_checkpoint( | ||
| 364 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 365 | ) | ||
| 366 | |||
| 367 | resnet_0 = middle_blocks[0] | ||
| 368 | attentions = middle_blocks[1] | ||
| 369 | resnet_1 = middle_blocks[2] | ||
| 370 | |||
| 371 | resnet_0_paths = renew_resnet_paths(resnet_0) | ||
| 372 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) | ||
| 373 | |||
| 374 | resnet_1_paths = renew_resnet_paths(resnet_1) | ||
| 375 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) | ||
| 376 | |||
| 377 | attentions_paths = renew_attention_paths(attentions) | ||
| 378 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} | ||
| 379 | assign_to_checkpoint( | ||
| 380 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 381 | ) | ||
| 382 | |||
| 383 | for i in range(num_output_blocks): | ||
| 384 | block_id = i // (config["layers_per_block"] + 1) | ||
| 385 | layer_in_block_id = i % (config["layers_per_block"] + 1) | ||
| 386 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] | ||
| 387 | output_block_list = {} | ||
| 388 | |||
| 389 | for layer in output_block_layers: | ||
| 390 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) | ||
| 391 | if layer_id in output_block_list: | ||
| 392 | output_block_list[layer_id].append(layer_name) | ||
| 393 | else: | ||
| 394 | output_block_list[layer_id] = [layer_name] | ||
| 395 | |||
| 396 | if len(output_block_list) > 1: | ||
| 397 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] | ||
| 398 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] | ||
| 399 | |||
| 400 | resnet_0_paths = renew_resnet_paths(resnets) | ||
| 401 | paths = renew_resnet_paths(resnets) | ||
| 402 | |||
| 403 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
| 404 | assign_to_checkpoint( | ||
| 405 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 406 | ) | ||
| 407 | |||
| 408 | if ["conv.weight", "conv.bias"] in output_block_list.values(): | ||
| 409 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) | ||
| 410 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ | ||
| 411 | f"output_blocks.{i}.{index}.conv.weight" | ||
| 412 | ] | ||
| 413 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ | ||
| 414 | f"output_blocks.{i}.{index}.conv.bias" | ||
| 415 | ] | ||
| 416 | |||
| 417 | # Clear attentions as they have been attributed above. | ||
| 418 | if len(attentions) == 2: | ||
| 419 | attentions = [] | ||
| 420 | |||
| 421 | if len(attentions): | ||
| 422 | paths = renew_attention_paths(attentions) | ||
| 423 | meta_path = { | ||
| 424 | "old": f"output_blocks.{i}.1", | ||
| 425 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", | ||
| 426 | } | ||
| 427 | assign_to_checkpoint( | ||
| 428 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 429 | ) | ||
| 430 | else: | ||
| 431 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) | ||
| 432 | for path in resnet_0_paths: | ||
| 433 | old_path = ".".join(["output_blocks", str(i), path["old"]]) | ||
| 434 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) | ||
| 435 | |||
| 436 | new_checkpoint[new_path] = unet_state_dict[old_path] | ||
| 437 | |||
| 438 | return new_checkpoint | ||
| 439 | |||
| 440 | |||
| 441 | def convert_ldm_vae_checkpoint(checkpoint, config): | ||
| 442 | # extract state dict for VAE | ||
| 443 | vae_state_dict = {} | ||
| 444 | vae_key = "first_stage_model." | ||
| 445 | keys = list(checkpoint.keys()) | ||
| 446 | for key in keys: | ||
| 447 | if key.startswith(vae_key): | ||
| 448 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) | ||
| 449 | |||
| 450 | new_checkpoint = {} | ||
| 451 | |||
| 452 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] | ||
| 453 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] | ||
| 454 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] | ||
| 455 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] | ||
| 456 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] | ||
| 457 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] | ||
| 458 | |||
| 459 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] | ||
| 460 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] | ||
| 461 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] | ||
| 462 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] | ||
| 463 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] | ||
| 464 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] | ||
| 465 | |||
| 466 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] | ||
| 467 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] | ||
| 468 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] | ||
| 469 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] | ||
| 470 | |||
| 471 | # Retrieves the keys for the encoder down blocks only | ||
| 472 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) | ||
| 473 | down_blocks = { | ||
| 474 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) | ||
| 475 | } | ||
| 476 | |||
| 477 | # Retrieves the keys for the decoder up blocks only | ||
| 478 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) | ||
| 479 | up_blocks = { | ||
| 480 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) | ||
| 481 | } | ||
| 482 | |||
| 483 | for i in range(num_down_blocks): | ||
| 484 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] | ||
| 485 | |||
| 486 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: | ||
| 487 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( | ||
| 488 | f"encoder.down.{i}.downsample.conv.weight" | ||
| 489 | ) | ||
| 490 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( | ||
| 491 | f"encoder.down.{i}.downsample.conv.bias" | ||
| 492 | ) | ||
| 493 | |||
| 494 | paths = renew_vae_resnet_paths(resnets) | ||
| 495 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} | ||
| 496 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 497 | |||
| 498 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] | ||
| 499 | num_mid_res_blocks = 2 | ||
| 500 | for i in range(1, num_mid_res_blocks + 1): | ||
| 501 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] | ||
| 502 | |||
| 503 | paths = renew_vae_resnet_paths(resnets) | ||
| 504 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
| 505 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 506 | |||
| 507 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] | ||
| 508 | paths = renew_vae_attention_paths(mid_attentions) | ||
| 509 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
| 510 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 511 | conv_attn_to_linear(new_checkpoint) | ||
| 512 | |||
| 513 | for i in range(num_up_blocks): | ||
| 514 | block_id = num_up_blocks - 1 - i | ||
| 515 | resnets = [ | ||
| 516 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key | ||
| 517 | ] | ||
| 518 | |||
| 519 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: | ||
| 520 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ | ||
| 521 | f"decoder.up.{block_id}.upsample.conv.weight" | ||
| 522 | ] | ||
| 523 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ | ||
| 524 | f"decoder.up.{block_id}.upsample.conv.bias" | ||
| 525 | ] | ||
| 526 | |||
| 527 | paths = renew_vae_resnet_paths(resnets) | ||
| 528 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} | ||
| 529 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 530 | |||
| 531 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] | ||
| 532 | num_mid_res_blocks = 2 | ||
| 533 | for i in range(1, num_mid_res_blocks + 1): | ||
| 534 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] | ||
| 535 | |||
| 536 | paths = renew_vae_resnet_paths(resnets) | ||
| 537 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
| 538 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 539 | |||
| 540 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] | ||
| 541 | paths = renew_vae_attention_paths(mid_attentions) | ||
| 542 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
| 543 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 544 | conv_attn_to_linear(new_checkpoint) | ||
| 545 | return new_checkpoint | ||
| 546 | |||
| 547 | |||
| 548 | def convert_ldm_bert_checkpoint(checkpoint, config): | ||
| 549 | def _copy_attn_layer(hf_attn_layer, pt_attn_layer): | ||
| 550 | hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight | ||
| 551 | hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight | ||
| 552 | hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight | ||
| 553 | |||
| 554 | hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight | ||
| 555 | hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias | ||
| 556 | |||
| 557 | def _copy_linear(hf_linear, pt_linear): | ||
| 558 | hf_linear.weight = pt_linear.weight | ||
| 559 | hf_linear.bias = pt_linear.bias | ||
| 560 | |||
| 561 | def _copy_layer(hf_layer, pt_layer): | ||
| 562 | # copy layer norms | ||
| 563 | _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) | ||
| 564 | _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) | ||
| 565 | |||
| 566 | # copy attn | ||
| 567 | _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) | ||
| 568 | |||
| 569 | # copy MLP | ||
| 570 | pt_mlp = pt_layer[1][1] | ||
| 571 | _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) | ||
| 572 | _copy_linear(hf_layer.fc2, pt_mlp.net[2]) | ||
| 573 | |||
| 574 | def _copy_layers(hf_layers, pt_layers): | ||
| 575 | for i, hf_layer in enumerate(hf_layers): | ||
| 576 | if i != 0: | ||
| 577 | i += i | ||
| 578 | pt_layer = pt_layers[i : i + 2] | ||
| 579 | _copy_layer(hf_layer, pt_layer) | ||
| 580 | |||
| 581 | hf_model = LDMBertModel(config).eval() | ||
| 582 | |||
| 583 | # copy embeds | ||
| 584 | hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight | ||
| 585 | hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight | ||
| 586 | |||
| 587 | # copy layer norm | ||
| 588 | _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) | ||
| 589 | |||
| 590 | # copy hidden layers | ||
| 591 | _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) | ||
| 592 | |||
| 593 | _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) | ||
| 594 | |||
| 595 | return hf_model | ||
| 596 | |||
| 597 | |||
| 598 | if __name__ == "__main__": | ||
| 599 | parser = argparse.ArgumentParser() | ||
| 600 | |||
| 601 | parser.add_argument( | ||
| 602 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." | ||
| 603 | ) | ||
| 604 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml | ||
| 605 | parser.add_argument( | ||
| 606 | "--original_config_file", | ||
| 607 | default=None, | ||
| 608 | type=str, | ||
| 609 | help="The YAML config file corresponding to the original architecture.", | ||
| 610 | ) | ||
| 611 | parser.add_argument( | ||
| 612 | "--scheduler_type", | ||
| 613 | default="pndm", | ||
| 614 | type=str, | ||
| 615 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']", | ||
| 616 | ) | ||
| 617 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") | ||
| 618 | |||
| 619 | args = parser.parse_args() | ||
| 620 | |||
| 621 | if args.original_config_file is None: | ||
| 622 | os.system( | ||
| 623 | "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" | ||
| 624 | ) | ||
| 625 | args.original_config_file = "./v1-inference.yaml" | ||
| 626 | |||
| 627 | original_config = OmegaConf.load(args.original_config_file) | ||
| 628 | checkpoint = torch.load(args.checkpoint_path)["state_dict"] | ||
| 629 | |||
| 630 | num_train_timesteps = original_config.model.params.timesteps | ||
| 631 | beta_start = original_config.model.params.linear_start | ||
| 632 | beta_end = original_config.model.params.linear_end | ||
| 633 | if args.scheduler_type == "pndm": | ||
| 634 | scheduler = PNDMScheduler( | ||
| 635 | beta_end=beta_end, | ||
| 636 | beta_schedule="scaled_linear", | ||
| 637 | beta_start=beta_start, | ||
| 638 | num_train_timesteps=num_train_timesteps, | ||
| 639 | skip_prk_steps=True, | ||
| 640 | ) | ||
| 641 | elif args.scheduler_type == "lms": | ||
| 642 | scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") | ||
| 643 | elif args.scheduler_type == "ddim": | ||
| 644 | scheduler = DDIMScheduler( | ||
| 645 | beta_start=beta_start, | ||
| 646 | beta_end=beta_end, | ||
| 647 | beta_schedule="scaled_linear", | ||
| 648 | clip_sample=False, | ||
| 649 | set_alpha_to_one=False, | ||
| 650 | ) | ||
| 651 | else: | ||
| 652 | raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") | ||
| 653 | |||
| 654 | # Convert the UNet2DConditionModel model. | ||
| 655 | unet_config = create_unet_diffusers_config(original_config) | ||
| 656 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) | ||
| 657 | |||
| 658 | unet = UNet2DConditionModel(**unet_config) | ||
| 659 | unet.load_state_dict(converted_unet_checkpoint) | ||
| 660 | |||
| 661 | # Convert the VAE model. | ||
| 662 | vae_config = create_vae_diffusers_config(original_config) | ||
| 663 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) | ||
| 664 | |||
| 665 | vae = AutoencoderKL(**vae_config) | ||
| 666 | vae.load_state_dict(converted_vae_checkpoint) | ||
| 667 | |||
| 668 | # Convert the text model. | ||
| 669 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] | ||
| 670 | if text_model_type == "FrozenCLIPEmbedder": | ||
| 671 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") | ||
| 672 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") | ||
| 673 | safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
| 674 | feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
| 675 | pipe = StableDiffusionPipeline( | ||
| 676 | vae=vae, | ||
| 677 | text_encoder=text_model, | ||
| 678 | tokenizer=tokenizer, | ||
| 679 | unet=unet, | ||
| 680 | scheduler=scheduler, | ||
| 681 | safety_checker=safety_checker, | ||
| 682 | feature_extractor=feature_extractor, | ||
| 683 | ) | ||
| 684 | else: | ||
| 685 | text_config = create_ldm_bert_config(original_config) | ||
| 686 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) | ||
| 687 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") | ||
| 688 | pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) | ||
| 689 | |||
| 690 | pipe.save_pretrained(args.dump_path) | ||
