summary refs log tree commit diff stats
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r-- scripts/convert_diffusers_to_original_stable_diffusion.py 234
-rw-r--r-- scripts/convert_original_stable_diffusion_to_diffusers.py 690
2 files changed, 0 insertions, 924 deletions
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py
deleted file mode 100644
index 9888f62..0000000
--- a/scripts/convert_diffusers_to_original_stable_diffusion.py
+++ /dev/null
@@ -1,234 +0,0 @@
1# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2# *Only* converts the UNet, VAE, and Text Encoder.
3# Does not convert optimizer state or any other thing.
4
5import argparse
6import os.path as osp
7
8import torch
9
10
11# =================#
12# UNet Conversion #
13# =================#
14
# Top-level UNet parameters as (stable-diffusion name, HF Diffusers name) pairs.
# Each module listed below contributes its ".weight" entry followed by its ".bias" entry.
unet_conversion_map = [
    (f"{sd_module}.{param}", f"{hf_module}.{param}")
    for sd_module, hf_module in (
        ("time_embed.0", "time_embedding.linear_1"),
        ("time_embed.2", "time_embedding.linear_2"),
        ("input_blocks.0.0", "conv_in"),
        ("out.0", "conv_norm_out"),
        ("out.2", "conv_out"),
    )
    for param in ("weight", "bias")
]
28
# Sub-module names inside a resnet block, emitted as (stable-diffusion, HF Diffusers) pairs.
# The table is written HF -> SD and flipped so the tuple layout matches the other maps.
unet_conversion_map_resnet = [
    (sd_part, hf_part)
    for hf_part, sd_part in {
        "norm1": "in_layers.0",
        "conv1": "in_layers.2",
        "norm2": "out_layers.0",
        "conv2": "out_layers.3",
        "time_emb_proj": "emb_layers.1",
        "conv_shortcut": "skip_connection",
    }.items()
]
38
def _unet_layer_prefix_pairs():
    """Yield (stable-diffusion prefix, HF Diffusers prefix) pairs for the UNet blocks.

    The block counts (4 down/up blocks, 2 resnets per down block, 3 per up block)
    are hardcoded; other network layouts would need smarter logic.
    """
    for block in range(4):
        # Down blocks: two resnets each, with an attention next to each resnet
        # in every block except the last one.
        for layer in range(2):
            sd_index = 3 * block + layer + 1
            yield f"input_blocks.{sd_index}.0.", f"down_blocks.{block}.resnets.{layer}."
            if block < 3:
                yield f"input_blocks.{sd_index}.1.", f"down_blocks.{block}.attentions.{layer}."

        # Up blocks: three resnets each, with attentions in every block except the first one.
        for layer in range(3):
            sd_index = 3 * block + layer
            yield f"output_blocks.{sd_index}.0.", f"up_blocks.{block}.resnets.{layer}."
            if block > 0:
                yield f"output_blocks.{sd_index}.1.", f"up_blocks.{block}.attentions.{layer}."

        if block < 3:
            # The last down block has no downsampler and the last up block has no upsampler.
            yield f"input_blocks.{3 * (block + 1)}.0.op.", f"down_blocks.{block}.downsamplers.0.conv."
            # The upsampler sits at sub-index 1 in the first up block (no attention there), else 2.
            yield f"output_blocks.{3 * block + 2}.{1 if block == 0 else 2}.", f"up_blocks.{block}.upsamplers.0."

    yield "middle_block.1.", "mid_block.attentions.0."
    for layer in range(2):
        yield f"middle_block.{2 * layer}.", f"mid_block.resnets.{layer}."


unet_conversion_map_layer = list(_unet_layer_prefix_pairs())
88
89
def convert_unet_state_dict(unet_state_dict):
    """Return a new dict holding the UNet weights under stable-diffusion key names.

    buyer beware: this is a *brittle* function. For every key the renames must be
    applied in this exact order: whole-key overrides from `unet_conversion_map`,
    then the resnet sub-module renames (only for keys containing "resnets"),
    then the block prefix renames from `unet_conversion_map_layer`.
    """
    sd_names = {hf_key: hf_key for hf_key in unet_state_dict}
    sd_names.update({hf_name: sd_name for sd_name, hf_name in unet_conversion_map})

    converted = {}
    for hf_key, sd_key in sd_names.items():
        if "resnets" in hf_key:
            for sd_part, hf_part in unet_conversion_map_resnet:
                sd_key = sd_key.replace(hf_part, sd_part)
        for sd_part, hf_part in unet_conversion_map_layer:
            sd_key = sd_key.replace(hf_part, sd_part)
        # Raises KeyError when a name listed in unet_conversion_map is absent from the input.
        converted[sd_key] = unet_state_dict[hf_key]
    return converted
108
109
110# ================#
111# VAE Conversion #
112# ================#
113
def _vae_prefix_pairs():
    """Yield (stable-diffusion, HF Diffusers) substring pairs for the VAE key names."""
    yield "nin_shortcut", "conv_shortcut"
    yield "norm_out", "conv_norm_out"
    yield "mid.attn_1.", "mid_block.attentions.0."

    for block in range(4):
        # down_blocks have two resnets
        for layer in range(2):
            yield f"encoder.down.{block}.block.{layer}.", f"encoder.down_blocks.{block}.resnets.{layer}."

        if block < 3:
            yield f"down.{block}.downsample.", f"down_blocks.{block}.downsamplers.0."
            yield f"up.{3 - block}.upsample.", f"up_blocks.{block}.upsamplers.0."

        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for layer in range(3):
            yield f"decoder.up.{3 - block}.block.{layer}.", f"decoder.up_blocks.{block}.resnets.{layer}."

    # these prefixes carry no encoder/decoder qualifier, so they cover the mid blocks of both
    for layer in range(2):
        yield f"mid.block_{layer + 1}.", f"mid_block.resnets.{layer}."


vae_conversion_map = list(_vae_prefix_pairs())
149
150
# Renames applied to VAE attention keys, emitted as (stable-diffusion, HF Diffusers) pairs.
# The table is written HF -> SD and flipped so the tuple layout matches the other maps.
vae_conversion_map_attn = [
    (sd_part, hf_part)
    for hf_part, sd_part in {
        "group_norm.": "norm.",
        "query.": "q.",
        "key.": "k.",
        "value.": "v.",
        "proj_attn.": "proj_out.",
    }.items()
]
159
160
def reshape_weight_for_sd(w):
    """Append two trailing singleton dimensions to `w` (HF linear weight -> SD conv2d weight)."""
    conv_shape = tuple(w.shape) + (1, 1)
    return w.reshape(conv_shape)
164
165
def convert_vae_state_dict(vae_state_dict):
    """Return a new dict holding the VAE weights under stable-diffusion key names.

    Every key gets the `vae_conversion_map` renames; keys whose HF name contains
    "attentions" additionally get the `vae_conversion_map_attn` renames. The mid-block
    attention q/k/v/proj_out weights are then passed through `reshape_weight_for_sd`.
    """
    new_state_dict = {}
    for hf_key in vae_state_dict:
        sd_key = hf_key
        for sd_part, hf_part in vae_conversion_map:
            sd_key = sd_key.replace(hf_part, sd_part)
        if "attentions" in hf_key:
            for sd_part, hf_part in vae_conversion_map_attn:
                sd_key = sd_key.replace(hf_part, sd_part)
        new_state_dict[sd_key] = vae_state_dict[hf_key]

    reshape_markers = [f"mid.attn_1.{name}.weight" for name in ("q", "k", "v", "proj_out")]
    # Only values of existing keys are reassigned, so iterating the dict here is safe.
    for sd_key, weight in new_state_dict.items():
        for marker in reshape_markers:
            if marker in sd_key:
                print(f"Reshaping {sd_key} for SD format")
                new_state_dict[sd_key] = reshape_weight_for_sd(weight)
    return new_state_dict
185
186
187# =========================#
188# Text Encoder Conversion #
189# =========================#
190# pretty much a no-op
191
192
def convert_text_enc_state_dict(text_enc_dict):
    """Return `text_enc_dict` itself, unchanged: the text encoder keys need no renaming."""
    converted = text_enc_dict
    return converted
195
196
if __name__ == "__main__":
    # Command-line entry point: load the three component state dicts of a saved
    # Diffusers pipeline, convert them, and save them as a single checkpoint.
    parser = argparse.ArgumentParser()
    for flag, help_text in (
        ("--model_path", "Path to the model to convert."),
        ("--checkpoint_path", "Path to the output model."),
    ):
        parser.add_argument(flag, default=None, type=str, required=True, help=help_text)
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"
    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    # (sub-folder, weights file, converter, key prefix in the output checkpoint)
    components = (
        ("unet", "diffusion_pytorch_model.bin", convert_unet_state_dict, "model.diffusion_model."),
        ("vae", "diffusion_pytorch_model.bin", convert_vae_state_dict, "first_stage_model."),
        ("text_encoder", "pytorch_model.bin", convert_text_enc_state_dict, "cond_stage_model.transformer."),
    )

    # Put together new checkpoint, one component after the other
    state_dict = {}
    for subfolder, filename, convert, prefix in components:
        weights = torch.load(osp.join(args.model_path, subfolder, filename), map_location="cpu")
        state_dict.update({prefix + k: v for k, v in convert(weights).items()})

    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    torch.save({"state_dict": state_dict}, args.checkpoint_path)
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
deleted file mode 100644
index ee7fc33..0000000
--- a/scripts/convert_original_stable_diffusion_to_diffusers.py
+++ /dev/null
@@ -1,690 +0,0 @@
1# coding=utf-8
2# Copyright 2022 The HuggingFace Inc. team.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Conversion script for the LDM checkpoints. """
16
17import argparse
18import os
19
20import torch
21
22
23try:
24 from omegaconf import OmegaConf
25except ImportError:
26 raise ImportError(
27 "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
28 )
29
30from diffusers import (
31 AutoencoderKL,
32 DDIMScheduler,
33 LDMTextToImagePipeline,
34 LMSDiscreteScheduler,
35 PNDMScheduler,
36 StableDiffusionPipeline,
37 UNet2DConditionModel,
38)
39from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
40from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
41from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
42
43
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drops dot-separated segments from `path`. A non-negative count removes that many leading
    segments; a negative count removes that many trailing segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments < 0:
        kept = segments[:n_shave_prefix_segments]
    else:
        kept = segments[n_shave_prefix_segments:]
    return ".".join(kept)
52
53
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Maps resnet parameter paths to the new naming scheme (local renaming).

    Returns one {"old": ..., "new": ...} dict per entry of `old_list`; "new" is the renamed
    path after `shave_segments` has been applied with `n_shave_prefix_segments`.
    """
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for old_part, new_part in renames:
            new_item = new_item.replace(old_part, new_part)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
74
75
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Maps VAE resnet parameter paths to the new naming scheme (local renaming).

    Returns one {"old": ..., "new": ...} dict per entry of `old_list`; "new" has
    "nin_shortcut" replaced by "conv_shortcut" and is then passed through `shave_segments`.
    """
    return [
        {
            "old": old_item,
            "new": shave_segments(
                old_item.replace("nin_shortcut", "conv_shortcut"),
                n_shave_prefix_segments=n_shave_prefix_segments,
            ),
        }
        for old_item in old_list
    ]
90
91
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Builds the {"old": ..., "new": ...} mapping for attention parameter paths (local renaming).

    No local renaming is applied: "new" equals "old" for every entry, and
    `n_shave_prefix_segments` is accepted but not used.
    """
    return [{"old": old_item, "new": old_item} for old_item in old_list]
111
112
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Maps VAE attention parameter paths to the new naming scheme (local renaming).

    Returns one {"old": ..., "new": ...} dict per entry of `old_list`; "new" is the renamed
    path after `shave_segments` has been applied with `n_shave_prefix_segments`.
    """
    # Module renames; each one is applied to its ".weight" and then its ".bias" parameter.
    module_renames = (
        ("norm", "group_norm"),
        ("q", "query"),
        ("k", "key"),
        ("v", "value"),
        ("proj_out", "proj_attn"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for old_module, new_module in module_renames:
            for param in ("weight", "bias"):
                new_item = new_item.replace(f"{old_module}.{param}", f"{new_module}.{param}")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
141
142
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    Final conversion step: applies the global renaming to locally converted paths and writes the
    corresponding weights from `old_checkpoint` into `checkpoint` (which is modified in place).

    `attention_paths_to_split` optionally maps an old path holding a fused tensor to a dict with
    "query", "key" and "value" target keys; `config["num_head_channels"]` is read in that case.
    `additional_replacements` is an optional list of {"old": ..., "new": ...} substring renames.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for fused_path, targets in attention_paths_to_split.items():
            fused = old_checkpoint[fused_path]
            channels = fused.shape[0] // 3
            num_heads = fused.shape[0] // config["num_head_channels"] // 3

            if len(fused.shape) == 3:
                target_shape = (-1, channels)
            else:
                target_shape = -1

            per_head = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
            parts = per_head.split(channels // num_heads, dim=1)
            query, key, value = parts

            for role, tensor in (("query", query), ("key", key), ("value", value)):
                checkpoint[targets[role]] = tensor.reshape(target_shape)

    global_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )

    for entry in paths:
        target = entry["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and target in attention_paths_to_split:
            continue

        # Global renaming happens here
        for old_part, new_part in global_renames:
            target = target.replace(old_part, new_part)

        if additional_replacements is not None:
            for replacement in additional_replacements:
                target = target.replace(replacement["old"], replacement["new"])

        weight = old_checkpoint[entry["old"]]
        # proj_attn.weight has to be converted from conv 1D to linear
        checkpoint[target] = weight[:, :, 0] if "proj_attn.weight" in target else weight
193
194
def conv_attn_to_linear(checkpoint):
    """
    Drops trailing dimensions of attention weights in `checkpoint`, in place.

    Entries whose last two key segments are query/key/value ".weight" are indexed with
    `[:, :, 0, 0]`, and entries containing "proj_attn.weight" with `[:, :, 0]`; in both cases
    only when the tensor has more than two dimensions. Returns None.
    """
    linear_suffixes = ("query.weight", "key.weight", "value.weight")
    for name in list(checkpoint.keys()):
        suffix = ".".join(name.split(".")[-2:])
        if suffix in linear_suffixes:
            if checkpoint[name].ndim > 2:
                checkpoint[name] = checkpoint[name][:, :, 0, 0]
        elif "proj_attn.weight" in name and checkpoint[name].ndim > 2:
            checkpoint[name] = checkpoint[name][:, :, 0]
205
206
def create_unet_diffusers_config(original_config):
    """
    Builds the diffusers UNet config dict from `model.params.unet_config.params` of the LDM config.
    """
    unet_params = original_config.model.params.unet_config.params
    attention_resolutions = unet_params.attention_resolutions

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
    num_blocks = len(block_out_channels)

    # The downsampling factor doubles from one down block to the next (1, 2, 4, ...);
    # the up blocks walk the same factors in reverse order.
    down_block_types = [
        "CrossAttnDownBlock2D" if 2**level in attention_resolutions else "DownBlock2D"
        for level in range(num_blocks)
    ]
    up_block_types = [
        "CrossAttnUpBlock2D" if 2 ** (num_blocks - 1 - level) in attention_resolutions else "UpBlock2D"
        for level in range(num_blocks)
    ]

    return {
        "sample_size": unet_params.image_size,
        "in_channels": unet_params.in_channels,
        "out_channels": unet_params.out_channels,
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": unet_params.context_dim,
        "attention_head_dim": unet_params.num_heads,
    }
242
243
def create_vae_diffusers_config(original_config):
    """
    Builds the diffusers VAE config dict from `model.params.first_stage_config.params` of the LDM config.
    """
    first_stage_params = original_config.model.params.first_stage_config.params
    vae_params = first_stage_params.ddconfig
    # embed_dim is read but does not end up in the returned config.
    _ = first_stage_params.embed_dim

    block_out_channels = tuple(vae_params.ch * mult for mult in vae_params.ch_mult)
    num_blocks = len(block_out_channels)

    return {
        "sample_size": vae_params.resolution,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": ("DownEncoderBlock2D",) * num_blocks,
        "up_block_types": ("UpDecoderBlock2D",) * num_blocks,
        "block_out_channels": block_out_channels,
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
266
267
def create_diffusers_schedular(original_config):
    """
    Builds a `DDIMScheduler` with the "scaled_linear" beta schedule from the timesteps,
    linear_start and linear_end values found under `model.params` of the LDM config.
    """
    model_params = original_config.model.params
    return DDIMScheduler(
        num_train_timesteps=model_params.timesteps,
        beta_start=model_params.linear_start,
        beta_end=model_params.linear_end,
        beta_schedule="scaled_linear",
    )
276
277
def create_ldm_bert_config(original_config):
    """
    Builds the `LDMBertConfig` from `model.params.cond_stage_config.params` of the LDM config.

    The feed-forward dimension is set to four times the embedding size.
    """
    # Fixed: this used to read `original_config.model.parms`, a misspelling of the
    # `params` node that every other config reader in this file goes through.
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config
286
287
def convert_ldm_unet_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Every key of `checkpoint` that starts with "model.diffusion_model." is popped from it (the
    input dict is mutated) and re-emitted under the diffusers UNet naming scheme in a new dict.
    `config["layers_per_block"]` is used to turn flat input/output block indices into
    (block, layer-in-block) pairs; `config` is also forwarded to `assign_to_checkpoint`.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # NOTE: the three groupings below use substring matching, so e.g. layer_id 1 also collects
    # the keys of blocks 10, 11, ...; the filters applied later include the sub-layer index
    # (".0", ".1") and narrow the selection back to the intended block.

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # input_blocks.0 is conv_in (handled above), so the down blocks start at index 1.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # An ".op" entry marks a downsampler; its keys are excluded from `resnets` above.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    # The middle_block.N -> mid_block.* renames for the resnets are applied inside assign_to_checkpoint.
    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]

        # Group the remaining name of every parameter by the sub-layer index it belongs to.
        output_block_list = {}
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            # (A second, unused renew_resnet_paths(resnets) call was removed here.)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-layer holding exactly conv.weight/conv.bias is treated as the upsampler.
            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Only one sub-layer: every parameter of the block is copied as part of a resnet.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
439
440
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns the VAE weights under the diffusers naming scheme.

    Only keys of `checkpoint` starting with "first_stage_model." are read; `checkpoint` itself is
    left untouched. `config` is forwarded to `assign_to_checkpoint`.
    """
    # extract state dict for VAE
    vae_key = "first_stage_model."
    vae_state_dict = {
        key.replace(vae_key, ""): checkpoint.get(key) for key in list(checkpoint.keys()) if key.startswith(vae_key)
    }

    new_checkpoint = {}

    # Parameters copied one-to-one; only norm_out is renamed (to conv_norm_out).
    for side in ("encoder", "decoder"):
        for new_name, old_name in (("conv_in", "conv_in"), ("conv_out", "conv_out"), ("conv_norm_out", "norm_out")):
            for param in ("weight", "bias"):
                new_checkpoint[f"{side}.{new_name}.{param}"] = vae_state_dict[f"{side}.{old_name}.{param}"]
    for conv in ("quant_conv", "post_quant_conv"):
        for param in ("weight", "bias"):
            new_checkpoint[f"{conv}.{param}"] = vae_state_dict[f"{conv}.{param}"]

    def count_blocks(marker):
        # Number of distinct three-segment prefixes among the keys containing `marker`.
        return len({".".join(name.split(".")[:3]) for name in vae_state_dict if marker in name})

    def convert_mid_block(side):
        # Converts the two mid resnets and the mid attention of the encoder or the decoder.
        mid_resnets = [key for key in vae_state_dict if f"{side}.mid.block" in key]
        for number in (1, 2):
            block_keys = [key for key in mid_resnets if f"{side}.mid.block_{number}" in key]
            rename = {"old": f"mid.block_{number}", "new": f"mid_block.resnets.{number - 1}"}
            assign_to_checkpoint(
                renew_vae_resnet_paths(block_keys),
                new_checkpoint,
                vae_state_dict,
                additional_replacements=[rename],
                config=config,
            )

        mid_attentions = [key for key in vae_state_dict if f"{side}.mid.attn" in key]
        rename = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
        assign_to_checkpoint(
            renew_vae_attention_paths(mid_attentions),
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[rename],
            config=config,
        )
        conv_attn_to_linear(new_checkpoint)

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = count_blocks("encoder.down")
    down_blocks = {
        block: [key for key in vae_state_dict if f"down.{block}" in key] for block in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = count_blocks("decoder.up")
    up_blocks = {block: [key for key in vae_state_dict if f"up.{block}" in key] for block in range(num_up_blocks)}

    for block in range(num_down_blocks):
        resnets = [key for key in down_blocks[block] if f"down.{block}.downsample" not in key]

        if f"encoder.down.{block}.downsample.conv.weight" in vae_state_dict:
            for param in ("weight", "bias"):
                new_checkpoint[f"encoder.down_blocks.{block}.downsamplers.0.conv.{param}"] = vae_state_dict.pop(
                    f"encoder.down.{block}.downsample.conv.{param}"
                )

        rename = {"old": f"down.{block}.block", "new": f"down_blocks.{block}.resnets"}
        assign_to_checkpoint(
            renew_vae_resnet_paths(resnets),
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[rename],
            config=config,
        )

    convert_mid_block("encoder")

    for new_block in range(num_up_blocks):
        # Up blocks are numbered in reverse order in the new scheme.
        old_block = num_up_blocks - 1 - new_block
        resnets = [key for key in up_blocks[old_block] if f"up.{old_block}.upsample" not in key]

        if f"decoder.up.{old_block}.upsample.conv.weight" in vae_state_dict:
            for param in ("weight", "bias"):
                new_checkpoint[f"decoder.up_blocks.{new_block}.upsamplers.0.conv.{param}"] = vae_state_dict[
                    f"decoder.up.{old_block}.upsample.conv.{param}"
                ]

        rename = {"old": f"up.{old_block}.block", "new": f"up_blocks.{new_block}.resnets"}
        assign_to_checkpoint(
            renew_vae_resnet_paths(resnets),
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[rename],
            config=config,
        )

    convert_mid_block("decoder")
    return new_checkpoint
546
547
def convert_ldm_bert_checkpoint(checkpoint, config):
    """
    Builds an `LDMBertModel` from `config` (in eval mode) and fills it with the parameters found
    under `checkpoint.transformer`. Returns the filled model.
    """

    def _copy_linear(hf_linear, pt_linear):
        # Shares weight and bias; also used for the layer norms below.
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
        # Only the weights of the q/k/v projections are copied; the output projection gets weight and bias.
        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias

    def _copy_layer(hf_layer, pt_layer):
        # pt_layer is a pair of source layers: [0] feeds the attention part, [1] the MLP part.
        attn_part, mlp_part = pt_layer[0], pt_layer[1]

        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, attn_part[0])
        _copy_linear(hf_layer.final_layer_norm, mlp_part[0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, attn_part[1])

        # copy MLP
        pt_mlp = mlp_part[1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])

    def _copy_layers(hf_layers, pt_layers):
        # Each target layer consumes two consecutive source layers.
        for index, hf_layer in enumerate(hf_layers):
            start = 2 * index
            _copy_layer(hf_layer, pt_layers[start : start + 2])

    hf_model = LDMBertModel(config).eval()
    transformer = checkpoint.transformer

    # copy embeds
    hf_model.model.embed_tokens.weight = transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = transformer.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, transformer.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, transformer.attn_layers.layers)

    _copy_linear(hf_model.to_logits, transformer.to_logits)

    return hf_model
596
597
if __name__ == "__main__":
    # Command-line entry point: convert an original (CompVis-style) checkpoint into a diffusers pipeline
    # and save it to `--dump_path`.
    #
    #   --checkpoint_path       (required) checkpoint file to convert
    #   --original_config_file  YAML config of the original architecture; downloaded when omitted
    #   --scheduler_type        one of 'pndm' (default), 'lms', 'ddim'
    #   --dump_path             (required) output directory for the converted pipeline
    import urllib.request

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--scheduler_type",
        default="pndm",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    if args.original_config_file is None:
        # No config given: fetch the Stable Diffusion v1 inference config. Using the standard library
        # instead of shelling out to `wget` removes the dependency on an external binary and turns a
        # failed download into an exception right here.
        default_config_url = (
            "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
        )
        args.original_config_file = "./v1-inference.yaml"
        urllib.request.urlretrieve(default_config_url, args.original_config_file)

    original_config = OmegaConf.load(args.original_config_file)

    # Build the scheduler before touching the checkpoint so that an unsupported `--scheduler_type`
    # is rejected without first loading a potentially very large file.
    num_train_timesteps = original_config.model.params.timesteps
    beta_start = original_config.model.params.linear_start
    beta_end = original_config.model.params.linear_end
    # NOTE(review): `num_train_timesteps` is only forwarded to the PNDM scheduler; the other two rely on
    # their own defaults — confirm this is intended.
    if args.scheduler_type == "pndm":
        scheduler = PNDMScheduler(
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            beta_start=beta_start,
            num_train_timesteps=num_train_timesteps,
            skip_prk_steps=True,
        )
    elif args.scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
    elif args.scheduler_type == "ddim":
        scheduler = DDIMScheduler(
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
    else:
        raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")

    # Map every tensor to CPU so checkpoints saved from a GPU also load on machines without CUDA.
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
    # Checkpoints usually wrap the weights under "state_dict"; also accept a bare state dict.
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model. The class name at the end of the conditioning-stage target selects the branch.
    text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
    if text_model_type == "FrozenCLIPEmbedder":
        # The CLIP text encoder, tokenizer and safety checker are loaded from pretrained repositories
        # rather than converted from the checkpoint.
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
        pipe = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
    else:
        text_config = create_ldm_bert_config(original_config)
        # NOTE(review): `convert_ldm_bert_checkpoint` reads `checkpoint.transformer` via attribute access,
        # while `checkpoint` here is the loaded state dict — confirm this branch works as intended.
        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    pipe.save_pretrained(args.dump_path)