diff options
author | Volpeon <git@volpeon.ink> | 2022-09-26 16:36:42 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-09-26 16:36:42 +0200 |
commit | 5588b93859c4380082a7e46bf5bef2119ec1907a (patch) | |
tree | 05a8292201912eb6f417eb2740c86df2153d1095 /scripts | |
download | textual-inversion-diff-5588b93859c4380082a7e46bf5bef2119ec1907a.tar.gz textual-inversion-diff-5588b93859c4380082a7e46bf5bef2119ec1907a.tar.bz2 textual-inversion-diff-5588b93859c4380082a7e46bf5bef2119ec1907a.zip |
Init
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/convert_original_stable_diffusion_to_diffusers.py | 690 |
1 files changed, 690 insertions, 0 deletions
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py new file mode 100644 index 0000000..ee7fc33 --- /dev/null +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py | |||
@@ -0,0 +1,690 @@ | |||
1 | # coding=utf-8 | ||
2 | # Copyright 2022 The HuggingFace Inc. team. | ||
3 | # | ||
4 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | # you may not use this file except in compliance with the License. | ||
6 | # You may obtain a copy of the License at | ||
7 | # | ||
8 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | # | ||
10 | # Unless required by applicable law or agreed to in writing, software | ||
11 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | # See the License for the specific language governing permissions and | ||
14 | # limitations under the License. | ||
15 | """ Conversion script for the LDM checkpoints. """ | ||
16 | |||
17 | import argparse | ||
18 | import os | ||
19 | |||
20 | import torch | ||
21 | |||
22 | |||
23 | try: | ||
24 | from omegaconf import OmegaConf | ||
25 | except ImportError: | ||
26 | raise ImportError( | ||
27 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." | ||
28 | ) | ||
29 | |||
30 | from diffusers import ( | ||
31 | AutoencoderKL, | ||
32 | DDIMScheduler, | ||
33 | LDMTextToImagePipeline, | ||
34 | LMSDiscreteScheduler, | ||
35 | PNDMScheduler, | ||
36 | StableDiffusionPipeline, | ||
37 | UNet2DConditionModel, | ||
38 | ) | ||
39 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel | ||
40 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
41 | from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer | ||
42 | |||
43 | |||
44 | def shave_segments(path, n_shave_prefix_segments=1): | ||
45 | """ | ||
46 | Removes segments. Positive values shave the first segments, negative shave the last segments. | ||
47 | """ | ||
48 | if n_shave_prefix_segments >= 0: | ||
49 | return ".".join(path.split(".")[n_shave_prefix_segments:]) | ||
50 | else: | ||
51 | return ".".join(path.split(".")[:n_shave_prefix_segments]) | ||
52 | |||
53 | |||
54 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
55 | """ | ||
56 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
57 | """ | ||
58 | mapping = [] | ||
59 | for old_item in old_list: | ||
60 | new_item = old_item.replace("in_layers.0", "norm1") | ||
61 | new_item = new_item.replace("in_layers.2", "conv1") | ||
62 | |||
63 | new_item = new_item.replace("out_layers.0", "norm2") | ||
64 | new_item = new_item.replace("out_layers.3", "conv2") | ||
65 | |||
66 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") | ||
67 | new_item = new_item.replace("skip_connection", "conv_shortcut") | ||
68 | |||
69 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
70 | |||
71 | mapping.append({"old": old_item, "new": new_item}) | ||
72 | |||
73 | return mapping | ||
74 | |||
75 | |||
76 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
77 | """ | ||
78 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
79 | """ | ||
80 | mapping = [] | ||
81 | for old_item in old_list: | ||
82 | new_item = old_item | ||
83 | |||
84 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") | ||
85 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
86 | |||
87 | mapping.append({"old": old_item, "new": new_item}) | ||
88 | |||
89 | return mapping | ||
90 | |||
91 | |||
92 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): | ||
93 | """ | ||
94 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
95 | """ | ||
96 | mapping = [] | ||
97 | for old_item in old_list: | ||
98 | new_item = old_item | ||
99 | |||
100 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') | ||
101 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') | ||
102 | |||
103 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') | ||
104 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') | ||
105 | |||
106 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
107 | |||
108 | mapping.append({"old": old_item, "new": new_item}) | ||
109 | |||
110 | return mapping | ||
111 | |||
112 | |||
113 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): | ||
114 | """ | ||
115 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
116 | """ | ||
117 | mapping = [] | ||
118 | for old_item in old_list: | ||
119 | new_item = old_item | ||
120 | |||
121 | new_item = new_item.replace("norm.weight", "group_norm.weight") | ||
122 | new_item = new_item.replace("norm.bias", "group_norm.bias") | ||
123 | |||
124 | new_item = new_item.replace("q.weight", "query.weight") | ||
125 | new_item = new_item.replace("q.bias", "query.bias") | ||
126 | |||
127 | new_item = new_item.replace("k.weight", "key.weight") | ||
128 | new_item = new_item.replace("k.bias", "key.bias") | ||
129 | |||
130 | new_item = new_item.replace("v.weight", "value.weight") | ||
131 | new_item = new_item.replace("v.bias", "value.bias") | ||
132 | |||
133 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") | ||
134 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") | ||
135 | |||
136 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
137 | |||
138 | mapping.append({"old": old_item, "new": new_item}) | ||
139 | |||
140 | return mapping | ||
141 | |||
142 | |||
143 | def assign_to_checkpoint( | ||
144 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None | ||
145 | ): | ||
146 | """ | ||
147 | This does the final conversion step: take locally converted weights and apply a global renaming | ||
148 | to them. It splits attention layers, and takes into account additional replacements | ||
149 | that may arise. | ||
150 | |||
151 | Assigns the weights to the new checkpoint. | ||
152 | """ | ||
153 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." | ||
154 | |||
155 | # Splits the attention layers into three variables. | ||
156 | if attention_paths_to_split is not None: | ||
157 | for path, path_map in attention_paths_to_split.items(): | ||
158 | old_tensor = old_checkpoint[path] | ||
159 | channels = old_tensor.shape[0] // 3 | ||
160 | |||
161 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) | ||
162 | |||
163 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 | ||
164 | |||
165 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) | ||
166 | query, key, value = old_tensor.split(channels // num_heads, dim=1) | ||
167 | |||
168 | checkpoint[path_map["query"]] = query.reshape(target_shape) | ||
169 | checkpoint[path_map["key"]] = key.reshape(target_shape) | ||
170 | checkpoint[path_map["value"]] = value.reshape(target_shape) | ||
171 | |||
172 | for path in paths: | ||
173 | new_path = path["new"] | ||
174 | |||
175 | # These have already been assigned | ||
176 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: | ||
177 | continue | ||
178 | |||
179 | # Global renaming happens here | ||
180 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") | ||
181 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") | ||
182 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") | ||
183 | |||
184 | if additional_replacements is not None: | ||
185 | for replacement in additional_replacements: | ||
186 | new_path = new_path.replace(replacement["old"], replacement["new"]) | ||
187 | |||
188 | # proj_attn.weight has to be converted from conv 1D to linear | ||
189 | if "proj_attn.weight" in new_path: | ||
190 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] | ||
191 | else: | ||
192 | checkpoint[new_path] = old_checkpoint[path["old"]] | ||
193 | |||
194 | |||
195 | def conv_attn_to_linear(checkpoint): | ||
196 | keys = list(checkpoint.keys()) | ||
197 | attn_keys = ["query.weight", "key.weight", "value.weight"] | ||
198 | for key in keys: | ||
199 | if ".".join(key.split(".")[-2:]) in attn_keys: | ||
200 | if checkpoint[key].ndim > 2: | ||
201 | checkpoint[key] = checkpoint[key][:, :, 0, 0] | ||
202 | elif "proj_attn.weight" in key: | ||
203 | if checkpoint[key].ndim > 2: | ||
204 | checkpoint[key] = checkpoint[key][:, :, 0] | ||
205 | |||
206 | |||
207 | def create_unet_diffusers_config(original_config): | ||
208 | """ | ||
209 | Creates a config for the diffusers based on the config of the LDM model. | ||
210 | """ | ||
211 | unet_params = original_config.model.params.unet_config.params | ||
212 | |||
213 | block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] | ||
214 | |||
215 | down_block_types = [] | ||
216 | resolution = 1 | ||
217 | for i in range(len(block_out_channels)): | ||
218 | block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" | ||
219 | down_block_types.append(block_type) | ||
220 | if i != len(block_out_channels) - 1: | ||
221 | resolution *= 2 | ||
222 | |||
223 | up_block_types = [] | ||
224 | for i in range(len(block_out_channels)): | ||
225 | block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" | ||
226 | up_block_types.append(block_type) | ||
227 | resolution //= 2 | ||
228 | |||
229 | config = dict( | ||
230 | sample_size=unet_params.image_size, | ||
231 | in_channels=unet_params.in_channels, | ||
232 | out_channels=unet_params.out_channels, | ||
233 | down_block_types=tuple(down_block_types), | ||
234 | up_block_types=tuple(up_block_types), | ||
235 | block_out_channels=tuple(block_out_channels), | ||
236 | layers_per_block=unet_params.num_res_blocks, | ||
237 | cross_attention_dim=unet_params.context_dim, | ||
238 | attention_head_dim=unet_params.num_heads, | ||
239 | ) | ||
240 | |||
241 | return config | ||
242 | |||
243 | |||
244 | def create_vae_diffusers_config(original_config): | ||
245 | """ | ||
246 | Creates a config for the diffusers based on the config of the LDM model. | ||
247 | """ | ||
248 | vae_params = original_config.model.params.first_stage_config.params.ddconfig | ||
249 | _ = original_config.model.params.first_stage_config.params.embed_dim | ||
250 | |||
251 | block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] | ||
252 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) | ||
253 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) | ||
254 | |||
255 | config = dict( | ||
256 | sample_size=vae_params.resolution, | ||
257 | in_channels=vae_params.in_channels, | ||
258 | out_channels=vae_params.out_ch, | ||
259 | down_block_types=tuple(down_block_types), | ||
260 | up_block_types=tuple(up_block_types), | ||
261 | block_out_channels=tuple(block_out_channels), | ||
262 | latent_channels=vae_params.z_channels, | ||
263 | layers_per_block=vae_params.num_res_blocks, | ||
264 | ) | ||
265 | return config | ||
266 | |||
267 | |||
268 | def create_diffusers_schedular(original_config): | ||
269 | schedular = DDIMScheduler( | ||
270 | num_train_timesteps=original_config.model.params.timesteps, | ||
271 | beta_start=original_config.model.params.linear_start, | ||
272 | beta_end=original_config.model.params.linear_end, | ||
273 | beta_schedule="scaled_linear", | ||
274 | ) | ||
275 | return schedular | ||
276 | |||
277 | |||
278 | def create_ldm_bert_config(original_config): | ||
279 | bert_params = original_config.model.parms.cond_stage_config.params | ||
280 | config = LDMBertConfig( | ||
281 | d_model=bert_params.n_embed, | ||
282 | encoder_layers=bert_params.n_layer, | ||
283 | encoder_ffn_dim=bert_params.n_embed * 4, | ||
284 | ) | ||
285 | return config | ||
286 | |||
287 | |||
288 | def convert_ldm_unet_checkpoint(checkpoint, config): | ||
289 | """ | ||
290 | Takes a state dict and a config, and returns a converted checkpoint. | ||
291 | """ | ||
292 | |||
293 | # extract state_dict for UNet | ||
294 | unet_state_dict = {} | ||
295 | unet_key = "model.diffusion_model." | ||
296 | keys = list(checkpoint.keys()) | ||
297 | for key in keys: | ||
298 | if key.startswith(unet_key): | ||
299 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) | ||
300 | |||
301 | new_checkpoint = {} | ||
302 | |||
303 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] | ||
304 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] | ||
305 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] | ||
306 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] | ||
307 | |||
308 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] | ||
309 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] | ||
310 | |||
311 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] | ||
312 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] | ||
313 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] | ||
314 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] | ||
315 | |||
316 | # Retrieves the keys for the input blocks only | ||
317 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) | ||
318 | input_blocks = { | ||
319 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] | ||
320 | for layer_id in range(num_input_blocks) | ||
321 | } | ||
322 | |||
323 | # Retrieves the keys for the middle blocks only | ||
324 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) | ||
325 | middle_blocks = { | ||
326 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] | ||
327 | for layer_id in range(num_middle_blocks) | ||
328 | } | ||
329 | |||
330 | # Retrieves the keys for the output blocks only | ||
331 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) | ||
332 | output_blocks = { | ||
333 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] | ||
334 | for layer_id in range(num_output_blocks) | ||
335 | } | ||
336 | |||
337 | for i in range(1, num_input_blocks): | ||
338 | block_id = (i - 1) // (config["layers_per_block"] + 1) | ||
339 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) | ||
340 | |||
341 | resnets = [ | ||
342 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key | ||
343 | ] | ||
344 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] | ||
345 | |||
346 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: | ||
347 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( | ||
348 | f"input_blocks.{i}.0.op.weight" | ||
349 | ) | ||
350 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( | ||
351 | f"input_blocks.{i}.0.op.bias" | ||
352 | ) | ||
353 | |||
354 | paths = renew_resnet_paths(resnets) | ||
355 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
356 | assign_to_checkpoint( | ||
357 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
358 | ) | ||
359 | |||
360 | if len(attentions): | ||
361 | paths = renew_attention_paths(attentions) | ||
362 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} | ||
363 | assign_to_checkpoint( | ||
364 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
365 | ) | ||
366 | |||
367 | resnet_0 = middle_blocks[0] | ||
368 | attentions = middle_blocks[1] | ||
369 | resnet_1 = middle_blocks[2] | ||
370 | |||
371 | resnet_0_paths = renew_resnet_paths(resnet_0) | ||
372 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) | ||
373 | |||
374 | resnet_1_paths = renew_resnet_paths(resnet_1) | ||
375 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) | ||
376 | |||
377 | attentions_paths = renew_attention_paths(attentions) | ||
378 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} | ||
379 | assign_to_checkpoint( | ||
380 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
381 | ) | ||
382 | |||
383 | for i in range(num_output_blocks): | ||
384 | block_id = i // (config["layers_per_block"] + 1) | ||
385 | layer_in_block_id = i % (config["layers_per_block"] + 1) | ||
386 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] | ||
387 | output_block_list = {} | ||
388 | |||
389 | for layer in output_block_layers: | ||
390 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) | ||
391 | if layer_id in output_block_list: | ||
392 | output_block_list[layer_id].append(layer_name) | ||
393 | else: | ||
394 | output_block_list[layer_id] = [layer_name] | ||
395 | |||
396 | if len(output_block_list) > 1: | ||
397 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] | ||
398 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] | ||
399 | |||
400 | resnet_0_paths = renew_resnet_paths(resnets) | ||
401 | paths = renew_resnet_paths(resnets) | ||
402 | |||
403 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
404 | assign_to_checkpoint( | ||
405 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
406 | ) | ||
407 | |||
408 | if ["conv.weight", "conv.bias"] in output_block_list.values(): | ||
409 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) | ||
410 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ | ||
411 | f"output_blocks.{i}.{index}.conv.weight" | ||
412 | ] | ||
413 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ | ||
414 | f"output_blocks.{i}.{index}.conv.bias" | ||
415 | ] | ||
416 | |||
417 | # Clear attentions as they have been attributed above. | ||
418 | if len(attentions) == 2: | ||
419 | attentions = [] | ||
420 | |||
421 | if len(attentions): | ||
422 | paths = renew_attention_paths(attentions) | ||
423 | meta_path = { | ||
424 | "old": f"output_blocks.{i}.1", | ||
425 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", | ||
426 | } | ||
427 | assign_to_checkpoint( | ||
428 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
429 | ) | ||
430 | else: | ||
431 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) | ||
432 | for path in resnet_0_paths: | ||
433 | old_path = ".".join(["output_blocks", str(i), path["old"]]) | ||
434 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) | ||
435 | |||
436 | new_checkpoint[new_path] = unet_state_dict[old_path] | ||
437 | |||
438 | return new_checkpoint | ||
439 | |||
440 | |||
441 | def convert_ldm_vae_checkpoint(checkpoint, config): | ||
442 | # extract state dict for VAE | ||
443 | vae_state_dict = {} | ||
444 | vae_key = "first_stage_model." | ||
445 | keys = list(checkpoint.keys()) | ||
446 | for key in keys: | ||
447 | if key.startswith(vae_key): | ||
448 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) | ||
449 | |||
450 | new_checkpoint = {} | ||
451 | |||
452 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] | ||
453 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] | ||
454 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] | ||
455 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] | ||
456 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] | ||
457 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] | ||
458 | |||
459 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] | ||
460 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] | ||
461 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] | ||
462 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] | ||
463 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] | ||
464 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] | ||
465 | |||
466 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] | ||
467 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] | ||
468 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] | ||
469 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] | ||
470 | |||
471 | # Retrieves the keys for the encoder down blocks only | ||
472 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) | ||
473 | down_blocks = { | ||
474 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) | ||
475 | } | ||
476 | |||
477 | # Retrieves the keys for the decoder up blocks only | ||
478 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) | ||
479 | up_blocks = { | ||
480 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) | ||
481 | } | ||
482 | |||
483 | for i in range(num_down_blocks): | ||
484 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] | ||
485 | |||
486 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: | ||
487 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( | ||
488 | f"encoder.down.{i}.downsample.conv.weight" | ||
489 | ) | ||
490 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( | ||
491 | f"encoder.down.{i}.downsample.conv.bias" | ||
492 | ) | ||
493 | |||
494 | paths = renew_vae_resnet_paths(resnets) | ||
495 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} | ||
496 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
497 | |||
498 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] | ||
499 | num_mid_res_blocks = 2 | ||
500 | for i in range(1, num_mid_res_blocks + 1): | ||
501 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] | ||
502 | |||
503 | paths = renew_vae_resnet_paths(resnets) | ||
504 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
505 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
506 | |||
507 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] | ||
508 | paths = renew_vae_attention_paths(mid_attentions) | ||
509 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
510 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
511 | conv_attn_to_linear(new_checkpoint) | ||
512 | |||
513 | for i in range(num_up_blocks): | ||
514 | block_id = num_up_blocks - 1 - i | ||
515 | resnets = [ | ||
516 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key | ||
517 | ] | ||
518 | |||
519 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: | ||
520 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ | ||
521 | f"decoder.up.{block_id}.upsample.conv.weight" | ||
522 | ] | ||
523 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ | ||
524 | f"decoder.up.{block_id}.upsample.conv.bias" | ||
525 | ] | ||
526 | |||
527 | paths = renew_vae_resnet_paths(resnets) | ||
528 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} | ||
529 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
530 | |||
531 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] | ||
532 | num_mid_res_blocks = 2 | ||
533 | for i in range(1, num_mid_res_blocks + 1): | ||
534 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] | ||
535 | |||
536 | paths = renew_vae_resnet_paths(resnets) | ||
537 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
538 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
539 | |||
540 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] | ||
541 | paths = renew_vae_attention_paths(mid_attentions) | ||
542 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
543 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
544 | conv_attn_to_linear(new_checkpoint) | ||
545 | return new_checkpoint | ||
546 | |||
547 | |||
548 | def convert_ldm_bert_checkpoint(checkpoint, config): | ||
549 | def _copy_attn_layer(hf_attn_layer, pt_attn_layer): | ||
550 | hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight | ||
551 | hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight | ||
552 | hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight | ||
553 | |||
554 | hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight | ||
555 | hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias | ||
556 | |||
557 | def _copy_linear(hf_linear, pt_linear): | ||
558 | hf_linear.weight = pt_linear.weight | ||
559 | hf_linear.bias = pt_linear.bias | ||
560 | |||
561 | def _copy_layer(hf_layer, pt_layer): | ||
562 | # copy layer norms | ||
563 | _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) | ||
564 | _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) | ||
565 | |||
566 | # copy attn | ||
567 | _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) | ||
568 | |||
569 | # copy MLP | ||
570 | pt_mlp = pt_layer[1][1] | ||
571 | _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) | ||
572 | _copy_linear(hf_layer.fc2, pt_mlp.net[2]) | ||
573 | |||
574 | def _copy_layers(hf_layers, pt_layers): | ||
575 | for i, hf_layer in enumerate(hf_layers): | ||
576 | if i != 0: | ||
577 | i += i | ||
578 | pt_layer = pt_layers[i : i + 2] | ||
579 | _copy_layer(hf_layer, pt_layer) | ||
580 | |||
581 | hf_model = LDMBertModel(config).eval() | ||
582 | |||
583 | # copy embeds | ||
584 | hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight | ||
585 | hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight | ||
586 | |||
587 | # copy layer norm | ||
588 | _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) | ||
589 | |||
590 | # copy hidden layers | ||
591 | _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) | ||
592 | |||
593 | _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) | ||
594 | |||
595 | return hf_model | ||
596 | |||
597 | |||
598 | if __name__ == "__main__": | ||
599 | parser = argparse.ArgumentParser() | ||
600 | |||
601 | parser.add_argument( | ||
602 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." | ||
603 | ) | ||
604 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml | ||
605 | parser.add_argument( | ||
606 | "--original_config_file", | ||
607 | default=None, | ||
608 | type=str, | ||
609 | help="The YAML config file corresponding to the original architecture.", | ||
610 | ) | ||
611 | parser.add_argument( | ||
612 | "--scheduler_type", | ||
613 | default="pndm", | ||
614 | type=str, | ||
615 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']", | ||
616 | ) | ||
617 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") | ||
618 | |||
619 | args = parser.parse_args() | ||
620 | |||
621 | if args.original_config_file is None: | ||
622 | os.system( | ||
623 | "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" | ||
624 | ) | ||
625 | args.original_config_file = "./v1-inference.yaml" | ||
626 | |||
627 | original_config = OmegaConf.load(args.original_config_file) | ||
628 | checkpoint = torch.load(args.checkpoint_path)["state_dict"] | ||
629 | |||
630 | num_train_timesteps = original_config.model.params.timesteps | ||
631 | beta_start = original_config.model.params.linear_start | ||
632 | beta_end = original_config.model.params.linear_end | ||
633 | if args.scheduler_type == "pndm": | ||
634 | scheduler = PNDMScheduler( | ||
635 | beta_end=beta_end, | ||
636 | beta_schedule="scaled_linear", | ||
637 | beta_start=beta_start, | ||
638 | num_train_timesteps=num_train_timesteps, | ||
639 | skip_prk_steps=True, | ||
640 | ) | ||
641 | elif args.scheduler_type == "lms": | ||
642 | scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") | ||
643 | elif args.scheduler_type == "ddim": | ||
644 | scheduler = DDIMScheduler( | ||
645 | beta_start=beta_start, | ||
646 | beta_end=beta_end, | ||
647 | beta_schedule="scaled_linear", | ||
648 | clip_sample=False, | ||
649 | set_alpha_to_one=False, | ||
650 | ) | ||
651 | else: | ||
652 | raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") | ||
653 | |||
654 | # Convert the UNet2DConditionModel model. | ||
655 | unet_config = create_unet_diffusers_config(original_config) | ||
656 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) | ||
657 | |||
658 | unet = UNet2DConditionModel(**unet_config) | ||
659 | unet.load_state_dict(converted_unet_checkpoint) | ||
660 | |||
661 | # Convert the VAE model. | ||
662 | vae_config = create_vae_diffusers_config(original_config) | ||
663 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) | ||
664 | |||
665 | vae = AutoencoderKL(**vae_config) | ||
666 | vae.load_state_dict(converted_vae_checkpoint) | ||
667 | |||
668 | # Convert the text model. | ||
669 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] | ||
670 | if text_model_type == "FrozenCLIPEmbedder": | ||
671 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") | ||
672 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") | ||
673 | safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
674 | feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
675 | pipe = StableDiffusionPipeline( | ||
676 | vae=vae, | ||
677 | text_encoder=text_model, | ||
678 | tokenizer=tokenizer, | ||
679 | unet=unet, | ||
680 | scheduler=scheduler, | ||
681 | safety_checker=safety_checker, | ||
682 | feature_extractor=feature_extractor, | ||
683 | ) | ||
684 | else: | ||
685 | text_config = create_ldm_bert_config(original_config) | ||
686 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) | ||
687 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") | ||
688 | pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) | ||
689 | |||
690 | pipe.save_pretrained(args.dump_path) | ||