diff options
-rw-r--r-- | data/csv.py | 21 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 61 | ||||
-rw-r--r-- | scripts/convert_diffusers_to_original_stable_diffusion.py | 234 | ||||
-rw-r--r-- | scripts/convert_original_stable_diffusion_to_diffusers.py | 690 | ||||
-rw-r--r-- | train_dreambooth.py | 14 | ||||
-rw-r--r-- | train_lora.py | 16 | ||||
-rw-r--r-- | train_ti.py | 14 | ||||
-rw-r--r-- | training/functional.py | 36 |
8 files changed, 99 insertions, 987 deletions
diff --git a/data/csv.py b/data/csv.py index fba5d4b..a6cd065 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -99,14 +99,16 @@ def generate_buckets( | |||
99 | return buckets, bucket_items, bucket_assignments | 99 | return buckets, bucket_items, bucket_assignments |
100 | 100 | ||
101 | 101 | ||
102 | def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples): | 102 | def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples): |
103 | prompt_ids = [example["prompt_ids"] for example in examples] | 103 | prompt_ids = [example["prompt_ids"] for example in examples] |
104 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 104 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
105 | 105 | ||
106 | input_ids = [example["instance_prompt_ids"] for example in examples] | 106 | input_ids = [example["instance_prompt_ids"] for example in examples] |
107 | pixel_values = [example["instance_images"] for example in examples] | 107 | pixel_values = [example["instance_images"] for example in examples] |
108 | 108 | ||
109 | if with_prior_preservation: | 109 | if with_guidance: |
110 | input_ids += [example["negative_prompt_ids"] for example in examples] | ||
111 | elif with_prior_preservation: | ||
110 | input_ids += [example["class_prompt_ids"] for example in examples] | 112 | input_ids += [example["class_prompt_ids"] for example in examples] |
111 | pixel_values += [example["class_images"] for example in examples] | 113 | pixel_values += [example["class_images"] for example in examples] |
112 | 114 | ||
@@ -133,7 +135,7 @@ class VlpnDataItem(NamedTuple): | |||
133 | class_image_path: Path | 135 | class_image_path: Path |
134 | prompt: list[str] | 136 | prompt: list[str] |
135 | cprompt: str | 137 | cprompt: str |
136 | nprompt: str | 138 | nprompt: list[str] |
137 | collection: list[str] | 139 | collection: list[str] |
138 | 140 | ||
139 | 141 | ||
@@ -163,6 +165,7 @@ class VlpnDataModule(): | |||
163 | data_file: str, | 165 | data_file: str, |
164 | tokenizer: CLIPTokenizer, | 166 | tokenizer: CLIPTokenizer, |
165 | class_subdir: str = "cls", | 167 | class_subdir: str = "cls", |
168 | with_guidance: bool = False, | ||
166 | num_class_images: int = 1, | 169 | num_class_images: int = 1, |
167 | size: int = 768, | 170 | size: int = 768, |
168 | num_buckets: int = 0, | 171 | num_buckets: int = 0, |
@@ -191,6 +194,7 @@ class VlpnDataModule(): | |||
191 | self.class_root = self.data_root / class_subdir | 194 | self.class_root = self.data_root / class_subdir |
192 | self.class_root.mkdir(parents=True, exist_ok=True) | 195 | self.class_root.mkdir(parents=True, exist_ok=True) |
193 | self.num_class_images = num_class_images | 196 | self.num_class_images = num_class_images |
197 | self.with_guidance = with_guidance | ||
194 | 198 | ||
195 | self.tokenizer = tokenizer | 199 | self.tokenizer = tokenizer |
196 | self.size = size | 200 | self.size = size |
@@ -228,10 +232,10 @@ class VlpnDataModule(): | |||
228 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 232 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
229 | expansions | 233 | expansions |
230 | )), | 234 | )), |
231 | keywords_to_prompt(prompt_to_keywords( | 235 | prompt_to_keywords( |
232 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 236 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
233 | expansions | 237 | expansions |
234 | )), | 238 | ), |
235 | item["collection"].split(", ") if "collection" in item else [] | 239 | item["collection"].split(", ") if "collection" in item else [] |
236 | ) | 240 | ) |
237 | for item in data | 241 | for item in data |
@@ -279,7 +283,7 @@ class VlpnDataModule(): | |||
279 | if self.seed is not None: | 283 | if self.seed is not None: |
280 | generator = generator.manual_seed(self.seed) | 284 | generator = generator.manual_seed(self.seed) |
281 | 285 | ||
282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 286 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) |
283 | 287 | ||
284 | if valid_set_size == 0: | 288 | if valid_set_size == 0: |
285 | data_train, data_val = items, items | 289 | data_train, data_val = items, items |
@@ -443,11 +447,14 @@ class VlpnDataset(IterableDataset): | |||
443 | example = {} | 447 | example = {} |
444 | 448 | ||
445 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) | 449 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) |
446 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 450 | example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt)) |
447 | 451 | ||
448 | example["instance_prompt_ids"] = self.get_input_ids( | 452 | example["instance_prompt_ids"] = self.get_input_ids( |
449 | keywords_to_prompt(item.prompt, self.dropout, True) | 453 | keywords_to_prompt(item.prompt, self.dropout, True) |
450 | ) | 454 | ) |
455 | example["negative_prompt_ids"] = self.get_input_ids( | ||
456 | keywords_to_prompt(item.nprompt, self.dropout, True) | ||
457 | ) | ||
451 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 458 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
452 | 459 | ||
453 | if self.num_class_images != 0: | 460 | if self.num_class_images != 0: |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index ea2a656..127ca50 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -307,39 +307,45 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
307 | 307 | ||
308 | return timesteps, num_inference_steps - t_start | 308 | return timesteps, num_inference_steps - t_start |
309 | 309 | ||
310 | def prepare_image(self, batch_size, width, height, dtype, device, generator=None): | 310 | def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None): |
311 | return (1.4 * perlin_noise( | 311 | offset_image = perlin_noise( |
312 | (batch_size, 1, width, height), | 312 | (batch_size, 1, width, height), |
313 | res=1, | 313 | res=1, |
314 | octaves=4, | ||
315 | generator=generator, | 314 | generator=generator, |
316 | dtype=dtype, | 315 | dtype=dtype, |
317 | device=device | 316 | device=device |
318 | )).clamp(-1, 1).expand(batch_size, 3, width, height) | 317 | ) |
318 | offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator) | ||
319 | offset_latents = self.vae.config.scaling_factor * offset_latents | ||
320 | return offset_latents | ||
319 | 321 | ||
320 | def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): | 322 | def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None): |
321 | init_image = init_image.to(device=device, dtype=dtype) | 323 | init_image = init_image.to(device=device, dtype=dtype) |
322 | init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) | 324 | latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) |
323 | init_latents = self.vae.config.scaling_factor * init_latents | 325 | latents = self.vae.config.scaling_factor * latents |
324 | 326 | ||
325 | if batch_size % init_latents.shape[0] != 0: | 327 | if batch_size % latents.shape[0] != 0: |
326 | raise ValueError( | 328 | raise ValueError( |
327 | f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." | 329 | f"Cannot duplicate `init_image` of batch size {latents.shape[0]} to {batch_size} text prompts." |
328 | ) | 330 | ) |
329 | else: | 331 | else: |
330 | batch_multiplier = batch_size // init_latents.shape[0] | 332 | batch_multiplier = batch_size // latents.shape[0] |
331 | init_latents = torch.cat([init_latents] * batch_multiplier, dim=0) | 333 | latents = torch.cat([latents] * batch_multiplier, dim=0) |
332 | 334 | ||
333 | # add noise to latents using the timesteps | 335 | # add noise to latents using the timesteps |
334 | noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) | 336 | noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) |
337 | |||
338 | if brightness_offset != 0: | ||
339 | noise += brightness_offset * self.prepare_brightness_offset( | ||
340 | batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator | ||
341 | ) | ||
335 | 342 | ||
336 | # get latents | 343 | # get latents |
337 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep) | 344 | latents = self.scheduler.add_noise(latents, noise, timestep) |
338 | latents = init_latents | ||
339 | 345 | ||
340 | return latents | 346 | return latents |
341 | 347 | ||
342 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): | 348 | def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None): |
343 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) | 349 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) |
344 | if isinstance(generator, list) and len(generator) != batch_size: | 350 | if isinstance(generator, list) and len(generator) != batch_size: |
345 | raise ValueError( | 351 | raise ValueError( |
@@ -352,6 +358,11 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
352 | else: | 358 | else: |
353 | latents = latents.to(device) | 359 | latents = latents.to(device) |
354 | 360 | ||
361 | if brightness_offset != 0: | ||
362 | latents += brightness_offset * self.prepare_brightness_offset( | ||
363 | batch_size, height, width, dtype, device, generator | ||
364 | ) | ||
365 | |||
355 | # scale the initial noise by the standard deviation required by the scheduler | 366 | # scale the initial noise by the standard deviation required by the scheduler |
356 | latents = latents * self.scheduler.init_noise_sigma | 367 | latents = latents * self.scheduler.init_noise_sigma |
357 | return latents | 368 | return latents |
@@ -395,7 +406,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
395 | sag_scale: float = 0.75, | 406 | sag_scale: float = 0.75, |
396 | eta: float = 0.0, | 407 | eta: float = 0.0, |
397 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 408 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
398 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None, | 409 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
410 | brightness_offset: Union[float, torch.FloatTensor] = 0, | ||
399 | output_type: str = "pil", | 411 | output_type: str = "pil", |
400 | return_dict: bool = True, | 412 | return_dict: bool = True, |
401 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 413 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
@@ -468,7 +480,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
468 | num_channels_latents = self.unet.in_channels | 480 | num_channels_latents = self.unet.in_channels |
469 | do_classifier_free_guidance = guidance_scale > 1.0 | 481 | do_classifier_free_guidance = guidance_scale > 1.0 |
470 | do_self_attention_guidance = sag_scale > 0.0 | 482 | do_self_attention_guidance = sag_scale > 0.0 |
471 | prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise" | 483 | prep_from_image = isinstance(image, PIL.Image.Image) |
472 | 484 | ||
473 | # 3. Encode input prompt | 485 | # 3. Encode input prompt |
474 | prompt_embeds = self.encode_prompt( | 486 | prompt_embeds = self.encode_prompt( |
@@ -482,15 +494,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
482 | # 4. Prepare latent variables | 494 | # 4. Prepare latent variables |
483 | if isinstance(image, PIL.Image.Image): | 495 | if isinstance(image, PIL.Image.Image): |
484 | image = preprocess(image) | 496 | image = preprocess(image) |
485 | elif image == "noise": | ||
486 | image = self.prepare_image( | ||
487 | batch_size * num_images_per_prompt, | ||
488 | width, | ||
489 | height, | ||
490 | prompt_embeds.dtype, | ||
491 | device, | ||
492 | generator | ||
493 | ) | ||
494 | 497 | ||
495 | # 5. Prepare timesteps | 498 | # 5. Prepare timesteps |
496 | self.scheduler.set_timesteps(num_inference_steps, device=device) | 499 | self.scheduler.set_timesteps(num_inference_steps, device=device) |
@@ -503,9 +506,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
503 | image, | 506 | image, |
504 | latent_timestep, | 507 | latent_timestep, |
505 | batch_size * num_images_per_prompt, | 508 | batch_size * num_images_per_prompt, |
509 | brightness_offset, | ||
506 | prompt_embeds.dtype, | 510 | prompt_embeds.dtype, |
507 | device, | 511 | device, |
508 | generator | 512 | generator, |
509 | ) | 513 | ) |
510 | else: | 514 | else: |
511 | latents = self.prepare_latents( | 515 | latents = self.prepare_latents( |
@@ -513,10 +517,11 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
513 | num_channels_latents, | 517 | num_channels_latents, |
514 | height, | 518 | height, |
515 | width, | 519 | width, |
520 | brightness_offset, | ||
516 | prompt_embeds.dtype, | 521 | prompt_embeds.dtype, |
517 | device, | 522 | device, |
518 | generator, | 523 | generator, |
519 | image | 524 | image, |
520 | ) | 525 | ) |
521 | 526 | ||
522 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | 527 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline |
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py deleted file mode 100644 index 9888f62..0000000 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ /dev/null | |||
@@ -1,234 +0,0 @@ | |||
1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. | ||
2 | # *Only* converts the UNet, VAE, and Text Encoder. | ||
3 | # Does not convert optimizer state or any other thing. | ||
4 | |||
5 | import argparse | ||
6 | import os.path as osp | ||
7 | |||
8 | import torch | ||
9 | |||
10 | |||
11 | # =================# | ||
12 | # UNet Conversion # | ||
13 | # =================# | ||
14 | |||
15 | unet_conversion_map = [ | ||
16 | # (stable-diffusion, HF Diffusers) | ||
17 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), | ||
18 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), | ||
19 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), | ||
20 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), | ||
21 | ("input_blocks.0.0.weight", "conv_in.weight"), | ||
22 | ("input_blocks.0.0.bias", "conv_in.bias"), | ||
23 | ("out.0.weight", "conv_norm_out.weight"), | ||
24 | ("out.0.bias", "conv_norm_out.bias"), | ||
25 | ("out.2.weight", "conv_out.weight"), | ||
26 | ("out.2.bias", "conv_out.bias"), | ||
27 | ] | ||
28 | |||
29 | unet_conversion_map_resnet = [ | ||
30 | # (stable-diffusion, HF Diffusers) | ||
31 | ("in_layers.0", "norm1"), | ||
32 | ("in_layers.2", "conv1"), | ||
33 | ("out_layers.0", "norm2"), | ||
34 | ("out_layers.3", "conv2"), | ||
35 | ("emb_layers.1", "time_emb_proj"), | ||
36 | ("skip_connection", "conv_shortcut"), | ||
37 | ] | ||
38 | |||
39 | unet_conversion_map_layer = [] | ||
40 | # hardcoded number of downblocks and resnets/attentions... | ||
41 | # would need smarter logic for other networks. | ||
42 | for i in range(4): | ||
43 | # loop over downblocks/upblocks | ||
44 | |||
45 | for j in range(2): | ||
46 | # loop over resnets/attentions for downblocks | ||
47 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." | ||
48 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." | ||
49 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) | ||
50 | |||
51 | if i < 3: | ||
52 | # no attention layers in down_blocks.3 | ||
53 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." | ||
54 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." | ||
55 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) | ||
56 | |||
57 | for j in range(3): | ||
58 | # loop over resnets/attentions for upblocks | ||
59 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." | ||
60 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." | ||
61 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) | ||
62 | |||
63 | if i > 0: | ||
64 | # no attention layers in up_blocks.0 | ||
65 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." | ||
66 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." | ||
67 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) | ||
68 | |||
69 | if i < 3: | ||
70 | # no downsample in down_blocks.3 | ||
71 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." | ||
72 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." | ||
73 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) | ||
74 | |||
75 | # no upsample in up_blocks.3 | ||
76 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." | ||
77 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." | ||
78 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) | ||
79 | |||
80 | hf_mid_atn_prefix = "mid_block.attentions.0." | ||
81 | sd_mid_atn_prefix = "middle_block.1." | ||
82 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) | ||
83 | |||
84 | for j in range(2): | ||
85 | hf_mid_res_prefix = f"mid_block.resnets.{j}." | ||
86 | sd_mid_res_prefix = f"middle_block.{2*j}." | ||
87 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) | ||
88 | |||
89 | |||
90 | def convert_unet_state_dict(unet_state_dict): | ||
91 | # buyer beware: this is a *brittle* function, | ||
92 | # and correct output requires that all of these pieces interact in | ||
93 | # the exact order in which I have arranged them. | ||
94 | mapping = {k: k for k in unet_state_dict.keys()} | ||
95 | for sd_name, hf_name in unet_conversion_map: | ||
96 | mapping[hf_name] = sd_name | ||
97 | for k, v in mapping.items(): | ||
98 | if "resnets" in k: | ||
99 | for sd_part, hf_part in unet_conversion_map_resnet: | ||
100 | v = v.replace(hf_part, sd_part) | ||
101 | mapping[k] = v | ||
102 | for k, v in mapping.items(): | ||
103 | for sd_part, hf_part in unet_conversion_map_layer: | ||
104 | v = v.replace(hf_part, sd_part) | ||
105 | mapping[k] = v | ||
106 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} | ||
107 | return new_state_dict | ||
108 | |||
109 | |||
110 | # ================# | ||
111 | # VAE Conversion # | ||
112 | # ================# | ||
113 | |||
114 | vae_conversion_map = [ | ||
115 | # (stable-diffusion, HF Diffusers) | ||
116 | ("nin_shortcut", "conv_shortcut"), | ||
117 | ("norm_out", "conv_norm_out"), | ||
118 | ("mid.attn_1.", "mid_block.attentions.0."), | ||
119 | ] | ||
120 | |||
121 | for i in range(4): | ||
122 | # down_blocks have two resnets | ||
123 | for j in range(2): | ||
124 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." | ||
125 | sd_down_prefix = f"encoder.down.{i}.block.{j}." | ||
126 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) | ||
127 | |||
128 | if i < 3: | ||
129 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." | ||
130 | sd_downsample_prefix = f"down.{i}.downsample." | ||
131 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) | ||
132 | |||
133 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." | ||
134 | sd_upsample_prefix = f"up.{3-i}.upsample." | ||
135 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) | ||
136 | |||
137 | # up_blocks have three resnets | ||
138 | # also, up blocks in hf are numbered in reverse from sd | ||
139 | for j in range(3): | ||
140 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." | ||
141 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." | ||
142 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) | ||
143 | |||
144 | # this part accounts for mid blocks in both the encoder and the decoder | ||
145 | for i in range(2): | ||
146 | hf_mid_res_prefix = f"mid_block.resnets.{i}." | ||
147 | sd_mid_res_prefix = f"mid.block_{i+1}." | ||
148 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) | ||
149 | |||
150 | |||
151 | vae_conversion_map_attn = [ | ||
152 | # (stable-diffusion, HF Diffusers) | ||
153 | ("norm.", "group_norm."), | ||
154 | ("q.", "query."), | ||
155 | ("k.", "key."), | ||
156 | ("v.", "value."), | ||
157 | ("proj_out.", "proj_attn."), | ||
158 | ] | ||
159 | |||
160 | |||
161 | def reshape_weight_for_sd(w): | ||
162 | # convert HF linear weights to SD conv2d weights | ||
163 | return w.reshape(*w.shape, 1, 1) | ||
164 | |||
165 | |||
166 | def convert_vae_state_dict(vae_state_dict): | ||
167 | mapping = {k: k for k in vae_state_dict.keys()} | ||
168 | for k, v in mapping.items(): | ||
169 | for sd_part, hf_part in vae_conversion_map: | ||
170 | v = v.replace(hf_part, sd_part) | ||
171 | mapping[k] = v | ||
172 | for k, v in mapping.items(): | ||
173 | if "attentions" in k: | ||
174 | for sd_part, hf_part in vae_conversion_map_attn: | ||
175 | v = v.replace(hf_part, sd_part) | ||
176 | mapping[k] = v | ||
177 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} | ||
178 | weights_to_convert = ["q", "k", "v", "proj_out"] | ||
179 | for k, v in new_state_dict.items(): | ||
180 | for weight_name in weights_to_convert: | ||
181 | if f"mid.attn_1.{weight_name}.weight" in k: | ||
182 | print(f"Reshaping {k} for SD format") | ||
183 | new_state_dict[k] = reshape_weight_for_sd(v) | ||
184 | return new_state_dict | ||
185 | |||
186 | |||
187 | # =========================# | ||
188 | # Text Encoder Conversion # | ||
189 | # =========================# | ||
190 | # pretty much a no-op | ||
191 | |||
192 | |||
193 | def convert_text_enc_state_dict(text_enc_dict): | ||
194 | return text_enc_dict | ||
195 | |||
196 | |||
197 | if __name__ == "__main__": | ||
198 | parser = argparse.ArgumentParser() | ||
199 | |||
200 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") | ||
201 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") | ||
202 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") | ||
203 | |||
204 | args = parser.parse_args() | ||
205 | |||
206 | assert args.model_path is not None, "Must provide a model path!" | ||
207 | |||
208 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" | ||
209 | |||
210 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") | ||
211 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") | ||
212 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") | ||
213 | |||
214 | # Convert the UNet model | ||
215 | unet_state_dict = torch.load(unet_path, map_location="cpu") | ||
216 | unet_state_dict = convert_unet_state_dict(unet_state_dict) | ||
217 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} | ||
218 | |||
219 | # Convert the VAE model | ||
220 | vae_state_dict = torch.load(vae_path, map_location="cpu") | ||
221 | vae_state_dict = convert_vae_state_dict(vae_state_dict) | ||
222 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} | ||
223 | |||
224 | # Convert the text encoder model | ||
225 | text_enc_dict = torch.load(text_enc_path, map_location="cpu") | ||
226 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) | ||
227 | text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} | ||
228 | |||
229 | # Put together new checkpoint | ||
230 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} | ||
231 | if args.half: | ||
232 | state_dict = {k: v.half() for k, v in state_dict.items()} | ||
233 | state_dict = {"state_dict": state_dict} | ||
234 | torch.save(state_dict, args.checkpoint_path) | ||
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py deleted file mode 100644 index ee7fc33..0000000 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ /dev/null | |||
@@ -1,690 +0,0 @@ | |||
1 | # coding=utf-8 | ||
2 | # Copyright 2022 The HuggingFace Inc. team. | ||
3 | # | ||
4 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | # you may not use this file except in compliance with the License. | ||
6 | # You may obtain a copy of the License at | ||
7 | # | ||
8 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | # | ||
10 | # Unless required by applicable law or agreed to in writing, software | ||
11 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | # See the License for the specific language governing permissions and | ||
14 | # limitations under the License. | ||
15 | """ Conversion script for the LDM checkpoints. """ | ||
16 | |||
17 | import argparse | ||
18 | import os | ||
19 | |||
20 | import torch | ||
21 | |||
22 | |||
23 | try: | ||
24 | from omegaconf import OmegaConf | ||
25 | except ImportError: | ||
26 | raise ImportError( | ||
27 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." | ||
28 | ) | ||
29 | |||
30 | from diffusers import ( | ||
31 | AutoencoderKL, | ||
32 | DDIMScheduler, | ||
33 | LDMTextToImagePipeline, | ||
34 | LMSDiscreteScheduler, | ||
35 | PNDMScheduler, | ||
36 | StableDiffusionPipeline, | ||
37 | UNet2DConditionModel, | ||
38 | ) | ||
39 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel | ||
40 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
41 | from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer | ||
42 | |||
43 | |||
44 | def shave_segments(path, n_shave_prefix_segments=1): | ||
45 | """ | ||
46 | Removes segments. Positive values shave the first segments, negative shave the last segments. | ||
47 | """ | ||
48 | if n_shave_prefix_segments >= 0: | ||
49 | return ".".join(path.split(".")[n_shave_prefix_segments:]) | ||
50 | else: | ||
51 | return ".".join(path.split(".")[:n_shave_prefix_segments]) | ||
52 | |||
53 | |||
54 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
55 | """ | ||
56 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
57 | """ | ||
58 | mapping = [] | ||
59 | for old_item in old_list: | ||
60 | new_item = old_item.replace("in_layers.0", "norm1") | ||
61 | new_item = new_item.replace("in_layers.2", "conv1") | ||
62 | |||
63 | new_item = new_item.replace("out_layers.0", "norm2") | ||
64 | new_item = new_item.replace("out_layers.3", "conv2") | ||
65 | |||
66 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") | ||
67 | new_item = new_item.replace("skip_connection", "conv_shortcut") | ||
68 | |||
69 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
70 | |||
71 | mapping.append({"old": old_item, "new": new_item}) | ||
72 | |||
73 | return mapping | ||
74 | |||
75 | |||
76 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
77 | """ | ||
78 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
79 | """ | ||
80 | mapping = [] | ||
81 | for old_item in old_list: | ||
82 | new_item = old_item | ||
83 | |||
84 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") | ||
85 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
86 | |||
87 | mapping.append({"old": old_item, "new": new_item}) | ||
88 | |||
89 | return mapping | ||
90 | |||
91 | |||
92 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): | ||
93 | """ | ||
94 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
95 | """ | ||
96 | mapping = [] | ||
97 | for old_item in old_list: | ||
98 | new_item = old_item | ||
99 | |||
100 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') | ||
101 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') | ||
102 | |||
103 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') | ||
104 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') | ||
105 | |||
106 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
107 | |||
108 | mapping.append({"old": old_item, "new": new_item}) | ||
109 | |||
110 | return mapping | ||
111 | |||
112 | |||
113 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): | ||
114 | """ | ||
115 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
116 | """ | ||
117 | mapping = [] | ||
118 | for old_item in old_list: | ||
119 | new_item = old_item | ||
120 | |||
121 | new_item = new_item.replace("norm.weight", "group_norm.weight") | ||
122 | new_item = new_item.replace("norm.bias", "group_norm.bias") | ||
123 | |||
124 | new_item = new_item.replace("q.weight", "query.weight") | ||
125 | new_item = new_item.replace("q.bias", "query.bias") | ||
126 | |||
127 | new_item = new_item.replace("k.weight", "key.weight") | ||
128 | new_item = new_item.replace("k.bias", "key.bias") | ||
129 | |||
130 | new_item = new_item.replace("v.weight", "value.weight") | ||
131 | new_item = new_item.replace("v.bias", "value.bias") | ||
132 | |||
133 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") | ||
134 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") | ||
135 | |||
136 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
137 | |||
138 | mapping.append({"old": old_item, "new": new_item}) | ||
139 | |||
140 | return mapping | ||
141 | |||
142 | |||
143 | def assign_to_checkpoint( | ||
144 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None | ||
145 | ): | ||
146 | """ | ||
147 | This does the final conversion step: take locally converted weights and apply a global renaming | ||
148 | to them. It splits attention layers, and takes into account additional replacements | ||
149 | that may arise. | ||
150 | |||
151 | Assigns the weights to the new checkpoint. | ||
152 | """ | ||
153 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." | ||
154 | |||
155 | # Splits the attention layers into three variables. | ||
156 | if attention_paths_to_split is not None: | ||
157 | for path, path_map in attention_paths_to_split.items(): | ||
158 | old_tensor = old_checkpoint[path] | ||
159 | channels = old_tensor.shape[0] // 3 | ||
160 | |||
161 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) | ||
162 | |||
163 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 | ||
164 | |||
165 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) | ||
166 | query, key, value = old_tensor.split(channels // num_heads, dim=1) | ||
167 | |||
168 | checkpoint[path_map["query"]] = query.reshape(target_shape) | ||
169 | checkpoint[path_map["key"]] = key.reshape(target_shape) | ||
170 | checkpoint[path_map["value"]] = value.reshape(target_shape) | ||
171 | |||
172 | for path in paths: | ||
173 | new_path = path["new"] | ||
174 | |||
175 | # These have already been assigned | ||
176 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: | ||
177 | continue | ||
178 | |||
179 | # Global renaming happens here | ||
180 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") | ||
181 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") | ||
182 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") | ||
183 | |||
184 | if additional_replacements is not None: | ||
185 | for replacement in additional_replacements: | ||
186 | new_path = new_path.replace(replacement["old"], replacement["new"]) | ||
187 | |||
188 | # proj_attn.weight has to be converted from conv 1D to linear | ||
189 | if "proj_attn.weight" in new_path: | ||
190 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] | ||
191 | else: | ||
192 | checkpoint[new_path] = old_checkpoint[path["old"]] | ||
193 | |||
194 | |||
195 | def conv_attn_to_linear(checkpoint): | ||
196 | keys = list(checkpoint.keys()) | ||
197 | attn_keys = ["query.weight", "key.weight", "value.weight"] | ||
198 | for key in keys: | ||
199 | if ".".join(key.split(".")[-2:]) in attn_keys: | ||
200 | if checkpoint[key].ndim > 2: | ||
201 | checkpoint[key] = checkpoint[key][:, :, 0, 0] | ||
202 | elif "proj_attn.weight" in key: | ||
203 | if checkpoint[key].ndim > 2: | ||
204 | checkpoint[key] = checkpoint[key][:, :, 0] | ||
205 | |||
206 | |||
207 | def create_unet_diffusers_config(original_config): | ||
208 | """ | ||
209 | Creates a config for the diffusers based on the config of the LDM model. | ||
210 | """ | ||
211 | unet_params = original_config.model.params.unet_config.params | ||
212 | |||
213 | block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] | ||
214 | |||
215 | down_block_types = [] | ||
216 | resolution = 1 | ||
217 | for i in range(len(block_out_channels)): | ||
218 | block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" | ||
219 | down_block_types.append(block_type) | ||
220 | if i != len(block_out_channels) - 1: | ||
221 | resolution *= 2 | ||
222 | |||
223 | up_block_types = [] | ||
224 | for i in range(len(block_out_channels)): | ||
225 | block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" | ||
226 | up_block_types.append(block_type) | ||
227 | resolution //= 2 | ||
228 | |||
229 | config = dict( | ||
230 | sample_size=unet_params.image_size, | ||
231 | in_channels=unet_params.in_channels, | ||
232 | out_channels=unet_params.out_channels, | ||
233 | down_block_types=tuple(down_block_types), | ||
234 | up_block_types=tuple(up_block_types), | ||
235 | block_out_channels=tuple(block_out_channels), | ||
236 | layers_per_block=unet_params.num_res_blocks, | ||
237 | cross_attention_dim=unet_params.context_dim, | ||
238 | attention_head_dim=unet_params.num_heads, | ||
239 | ) | ||
240 | |||
241 | return config | ||
242 | |||
243 | |||
244 | def create_vae_diffusers_config(original_config): | ||
245 | """ | ||
246 | Creates a config for the diffusers based on the config of the LDM model. | ||
247 | """ | ||
248 | vae_params = original_config.model.params.first_stage_config.params.ddconfig | ||
249 | _ = original_config.model.params.first_stage_config.params.embed_dim | ||
250 | |||
251 | block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] | ||
252 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) | ||
253 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) | ||
254 | |||
255 | config = dict( | ||
256 | sample_size=vae_params.resolution, | ||
257 | in_channels=vae_params.in_channels, | ||
258 | out_channels=vae_params.out_ch, | ||
259 | down_block_types=tuple(down_block_types), | ||
260 | up_block_types=tuple(up_block_types), | ||
261 | block_out_channels=tuple(block_out_channels), | ||
262 | latent_channels=vae_params.z_channels, | ||
263 | layers_per_block=vae_params.num_res_blocks, | ||
264 | ) | ||
265 | return config | ||
266 | |||
267 | |||
268 | def create_diffusers_schedular(original_config): | ||
269 | schedular = DDIMScheduler( | ||
270 | num_train_timesteps=original_config.model.params.timesteps, | ||
271 | beta_start=original_config.model.params.linear_start, | ||
272 | beta_end=original_config.model.params.linear_end, | ||
273 | beta_schedule="scaled_linear", | ||
274 | ) | ||
275 | return schedular | ||
276 | |||
277 | |||
278 | def create_ldm_bert_config(original_config): | ||
279 | bert_params = original_config.model.parms.cond_stage_config.params | ||
280 | config = LDMBertConfig( | ||
281 | d_model=bert_params.n_embed, | ||
282 | encoder_layers=bert_params.n_layer, | ||
283 | encoder_ffn_dim=bert_params.n_embed * 4, | ||
284 | ) | ||
285 | return config | ||
286 | |||
287 | |||
288 | def convert_ldm_unet_checkpoint(checkpoint, config): | ||
289 | """ | ||
290 | Takes a state dict and a config, and returns a converted checkpoint. | ||
291 | """ | ||
292 | |||
293 | # extract state_dict for UNet | ||
294 | unet_state_dict = {} | ||
295 | unet_key = "model.diffusion_model." | ||
296 | keys = list(checkpoint.keys()) | ||
297 | for key in keys: | ||
298 | if key.startswith(unet_key): | ||
299 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) | ||
300 | |||
301 | new_checkpoint = {} | ||
302 | |||
303 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] | ||
304 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] | ||
305 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] | ||
306 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] | ||
307 | |||
308 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] | ||
309 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] | ||
310 | |||
311 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] | ||
312 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] | ||
313 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] | ||
314 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] | ||
315 | |||
316 | # Retrieves the keys for the input blocks only | ||
317 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) | ||
318 | input_blocks = { | ||
319 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] | ||
320 | for layer_id in range(num_input_blocks) | ||
321 | } | ||
322 | |||
323 | # Retrieves the keys for the middle blocks only | ||
324 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) | ||
325 | middle_blocks = { | ||
326 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] | ||
327 | for layer_id in range(num_middle_blocks) | ||
328 | } | ||
329 | |||
330 | # Retrieves the keys for the output blocks only | ||
331 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) | ||
332 | output_blocks = { | ||
333 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] | ||
334 | for layer_id in range(num_output_blocks) | ||
335 | } | ||
336 | |||
337 | for i in range(1, num_input_blocks): | ||
338 | block_id = (i - 1) // (config["layers_per_block"] + 1) | ||
339 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) | ||
340 | |||
341 | resnets = [ | ||
342 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key | ||
343 | ] | ||
344 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] | ||
345 | |||
346 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: | ||
347 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( | ||
348 | f"input_blocks.{i}.0.op.weight" | ||
349 | ) | ||
350 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( | ||
351 | f"input_blocks.{i}.0.op.bias" | ||
352 | ) | ||
353 | |||
354 | paths = renew_resnet_paths(resnets) | ||
355 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
356 | assign_to_checkpoint( | ||
357 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
358 | ) | ||
359 | |||
360 | if len(attentions): | ||
361 | paths = renew_attention_paths(attentions) | ||
362 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} | ||
363 | assign_to_checkpoint( | ||
364 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
365 | ) | ||
366 | |||
367 | resnet_0 = middle_blocks[0] | ||
368 | attentions = middle_blocks[1] | ||
369 | resnet_1 = middle_blocks[2] | ||
370 | |||
371 | resnet_0_paths = renew_resnet_paths(resnet_0) | ||
372 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) | ||
373 | |||
374 | resnet_1_paths = renew_resnet_paths(resnet_1) | ||
375 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) | ||
376 | |||
377 | attentions_paths = renew_attention_paths(attentions) | ||
378 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} | ||
379 | assign_to_checkpoint( | ||
380 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
381 | ) | ||
382 | |||
383 | for i in range(num_output_blocks): | ||
384 | block_id = i // (config["layers_per_block"] + 1) | ||
385 | layer_in_block_id = i % (config["layers_per_block"] + 1) | ||
386 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] | ||
387 | output_block_list = {} | ||
388 | |||
389 | for layer in output_block_layers: | ||
390 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) | ||
391 | if layer_id in output_block_list: | ||
392 | output_block_list[layer_id].append(layer_name) | ||
393 | else: | ||
394 | output_block_list[layer_id] = [layer_name] | ||
395 | |||
396 | if len(output_block_list) > 1: | ||
397 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] | ||
398 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] | ||
399 | |||
400 | resnet_0_paths = renew_resnet_paths(resnets) | ||
401 | paths = renew_resnet_paths(resnets) | ||
402 | |||
403 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
404 | assign_to_checkpoint( | ||
405 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
406 | ) | ||
407 | |||
408 | if ["conv.weight", "conv.bias"] in output_block_list.values(): | ||
409 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) | ||
410 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ | ||
411 | f"output_blocks.{i}.{index}.conv.weight" | ||
412 | ] | ||
413 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ | ||
414 | f"output_blocks.{i}.{index}.conv.bias" | ||
415 | ] | ||
416 | |||
417 | # Clear attentions as they have been attributed above. | ||
418 | if len(attentions) == 2: | ||
419 | attentions = [] | ||
420 | |||
421 | if len(attentions): | ||
422 | paths = renew_attention_paths(attentions) | ||
423 | meta_path = { | ||
424 | "old": f"output_blocks.{i}.1", | ||
425 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", | ||
426 | } | ||
427 | assign_to_checkpoint( | ||
428 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
429 | ) | ||
430 | else: | ||
431 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) | ||
432 | for path in resnet_0_paths: | ||
433 | old_path = ".".join(["output_blocks", str(i), path["old"]]) | ||
434 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) | ||
435 | |||
436 | new_checkpoint[new_path] = unet_state_dict[old_path] | ||
437 | |||
438 | return new_checkpoint | ||
439 | |||
440 | |||
441 | def convert_ldm_vae_checkpoint(checkpoint, config): | ||
442 | # extract state dict for VAE | ||
443 | vae_state_dict = {} | ||
444 | vae_key = "first_stage_model." | ||
445 | keys = list(checkpoint.keys()) | ||
446 | for key in keys: | ||
447 | if key.startswith(vae_key): | ||
448 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) | ||
449 | |||
450 | new_checkpoint = {} | ||
451 | |||
452 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] | ||
453 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] | ||
454 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] | ||
455 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] | ||
456 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] | ||
457 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] | ||
458 | |||
459 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] | ||
460 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] | ||
461 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] | ||
462 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] | ||
463 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] | ||
464 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] | ||
465 | |||
466 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] | ||
467 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] | ||
468 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] | ||
469 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] | ||
470 | |||
471 | # Retrieves the keys for the encoder down blocks only | ||
472 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) | ||
473 | down_blocks = { | ||
474 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) | ||
475 | } | ||
476 | |||
477 | # Retrieves the keys for the decoder up blocks only | ||
478 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) | ||
479 | up_blocks = { | ||
480 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) | ||
481 | } | ||
482 | |||
483 | for i in range(num_down_blocks): | ||
484 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] | ||
485 | |||
486 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: | ||
487 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( | ||
488 | f"encoder.down.{i}.downsample.conv.weight" | ||
489 | ) | ||
490 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( | ||
491 | f"encoder.down.{i}.downsample.conv.bias" | ||
492 | ) | ||
493 | |||
494 | paths = renew_vae_resnet_paths(resnets) | ||
495 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} | ||
496 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
497 | |||
498 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] | ||
499 | num_mid_res_blocks = 2 | ||
500 | for i in range(1, num_mid_res_blocks + 1): | ||
501 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] | ||
502 | |||
503 | paths = renew_vae_resnet_paths(resnets) | ||
504 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
505 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
506 | |||
507 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] | ||
508 | paths = renew_vae_attention_paths(mid_attentions) | ||
509 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
510 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
511 | conv_attn_to_linear(new_checkpoint) | ||
512 | |||
513 | for i in range(num_up_blocks): | ||
514 | block_id = num_up_blocks - 1 - i | ||
515 | resnets = [ | ||
516 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key | ||
517 | ] | ||
518 | |||
519 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: | ||
520 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ | ||
521 | f"decoder.up.{block_id}.upsample.conv.weight" | ||
522 | ] | ||
523 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ | ||
524 | f"decoder.up.{block_id}.upsample.conv.bias" | ||
525 | ] | ||
526 | |||
527 | paths = renew_vae_resnet_paths(resnets) | ||
528 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} | ||
529 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
530 | |||
531 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] | ||
532 | num_mid_res_blocks = 2 | ||
533 | for i in range(1, num_mid_res_blocks + 1): | ||
534 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] | ||
535 | |||
536 | paths = renew_vae_resnet_paths(resnets) | ||
537 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
538 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
539 | |||
540 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] | ||
541 | paths = renew_vae_attention_paths(mid_attentions) | ||
542 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
543 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
544 | conv_attn_to_linear(new_checkpoint) | ||
545 | return new_checkpoint | ||
546 | |||
547 | |||
548 | def convert_ldm_bert_checkpoint(checkpoint, config): | ||
549 | def _copy_attn_layer(hf_attn_layer, pt_attn_layer): | ||
550 | hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight | ||
551 | hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight | ||
552 | hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight | ||
553 | |||
554 | hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight | ||
555 | hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias | ||
556 | |||
557 | def _copy_linear(hf_linear, pt_linear): | ||
558 | hf_linear.weight = pt_linear.weight | ||
559 | hf_linear.bias = pt_linear.bias | ||
560 | |||
561 | def _copy_layer(hf_layer, pt_layer): | ||
562 | # copy layer norms | ||
563 | _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) | ||
564 | _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) | ||
565 | |||
566 | # copy attn | ||
567 | _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) | ||
568 | |||
569 | # copy MLP | ||
570 | pt_mlp = pt_layer[1][1] | ||
571 | _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) | ||
572 | _copy_linear(hf_layer.fc2, pt_mlp.net[2]) | ||
573 | |||
574 | def _copy_layers(hf_layers, pt_layers): | ||
575 | for i, hf_layer in enumerate(hf_layers): | ||
576 | if i != 0: | ||
577 | i += i | ||
578 | pt_layer = pt_layers[i : i + 2] | ||
579 | _copy_layer(hf_layer, pt_layer) | ||
580 | |||
581 | hf_model = LDMBertModel(config).eval() | ||
582 | |||
583 | # copy embeds | ||
584 | hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight | ||
585 | hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight | ||
586 | |||
587 | # copy layer norm | ||
588 | _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) | ||
589 | |||
590 | # copy hidden layers | ||
591 | _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) | ||
592 | |||
593 | _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) | ||
594 | |||
595 | return hf_model | ||
596 | |||
597 | |||
598 | if __name__ == "__main__": | ||
599 | parser = argparse.ArgumentParser() | ||
600 | |||
601 | parser.add_argument( | ||
602 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." | ||
603 | ) | ||
604 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml | ||
605 | parser.add_argument( | ||
606 | "--original_config_file", | ||
607 | default=None, | ||
608 | type=str, | ||
609 | help="The YAML config file corresponding to the original architecture.", | ||
610 | ) | ||
611 | parser.add_argument( | ||
612 | "--scheduler_type", | ||
613 | default="pndm", | ||
614 | type=str, | ||
615 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']", | ||
616 | ) | ||
617 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") | ||
618 | |||
619 | args = parser.parse_args() | ||
620 | |||
621 | if args.original_config_file is None: | ||
622 | os.system( | ||
623 | "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" | ||
624 | ) | ||
625 | args.original_config_file = "./v1-inference.yaml" | ||
626 | |||
627 | original_config = OmegaConf.load(args.original_config_file) | ||
628 | checkpoint = torch.load(args.checkpoint_path)["state_dict"] | ||
629 | |||
630 | num_train_timesteps = original_config.model.params.timesteps | ||
631 | beta_start = original_config.model.params.linear_start | ||
632 | beta_end = original_config.model.params.linear_end | ||
633 | if args.scheduler_type == "pndm": | ||
634 | scheduler = PNDMScheduler( | ||
635 | beta_end=beta_end, | ||
636 | beta_schedule="scaled_linear", | ||
637 | beta_start=beta_start, | ||
638 | num_train_timesteps=num_train_timesteps, | ||
639 | skip_prk_steps=True, | ||
640 | ) | ||
641 | elif args.scheduler_type == "lms": | ||
642 | scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") | ||
643 | elif args.scheduler_type == "ddim": | ||
644 | scheduler = DDIMScheduler( | ||
645 | beta_start=beta_start, | ||
646 | beta_end=beta_end, | ||
647 | beta_schedule="scaled_linear", | ||
648 | clip_sample=False, | ||
649 | set_alpha_to_one=False, | ||
650 | ) | ||
651 | else: | ||
652 | raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") | ||
653 | |||
654 | # Convert the UNet2DConditionModel model. | ||
655 | unet_config = create_unet_diffusers_config(original_config) | ||
656 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) | ||
657 | |||
658 | unet = UNet2DConditionModel(**unet_config) | ||
659 | unet.load_state_dict(converted_unet_checkpoint) | ||
660 | |||
661 | # Convert the VAE model. | ||
662 | vae_config = create_vae_diffusers_config(original_config) | ||
663 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) | ||
664 | |||
665 | vae = AutoencoderKL(**vae_config) | ||
666 | vae.load_state_dict(converted_vae_checkpoint) | ||
667 | |||
668 | # Convert the text model. | ||
669 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] | ||
670 | if text_model_type == "FrozenCLIPEmbedder": | ||
671 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") | ||
672 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") | ||
673 | safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
674 | feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
675 | pipe = StableDiffusionPipeline( | ||
676 | vae=vae, | ||
677 | text_encoder=text_model, | ||
678 | tokenizer=tokenizer, | ||
679 | unet=unet, | ||
680 | scheduler=scheduler, | ||
681 | safety_checker=safety_checker, | ||
682 | feature_extractor=feature_extractor, | ||
683 | ) | ||
684 | else: | ||
685 | text_config = create_ldm_bert_config(original_config) | ||
686 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) | ||
687 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") | ||
688 | pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) | ||
689 | |||
690 | pipe.save_pretrained(args.dump_path) | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1b8a3d2..7a33bca 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -110,7 +110,7 @@ def parse_args(): | |||
110 | parser.add_argument( | 110 | parser.add_argument( |
111 | "--tag_dropout", | 111 | "--tag_dropout", |
112 | type=float, | 112 | type=float, |
113 | default=0.1, | 113 | default=0, |
114 | help="Tag dropout probability.", | 114 | help="Tag dropout probability.", |
115 | ) | 115 | ) |
116 | parser.add_argument( | 116 | parser.add_argument( |
@@ -131,6 +131,11 @@ def parse_args(): | |||
131 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', | 131 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', |
132 | ) | 132 | ) |
133 | parser.add_argument( | 133 | parser.add_argument( |
134 | "--guidance_scale", | ||
135 | type=float, | ||
136 | default=0, | ||
137 | ) | ||
138 | parser.add_argument( | ||
134 | "--num_class_images", | 139 | "--num_class_images", |
135 | type=int, | 140 | type=int, |
136 | default=0, | 141 | default=0, |
@@ -178,7 +183,7 @@ def parse_args(): | |||
178 | parser.add_argument( | 183 | parser.add_argument( |
179 | "--offset_noise_strength", | 184 | "--offset_noise_strength", |
180 | type=float, | 185 | type=float, |
181 | default=0.15, | 186 | default=0, |
182 | help="Perlin offset noise strength.", | 187 | help="Perlin offset noise strength.", |
183 | ) | 188 | ) |
184 | parser.add_argument( | 189 | parser.add_argument( |
@@ -557,8 +562,8 @@ def main(): | |||
557 | vae=vae, | 562 | vae=vae, |
558 | noise_scheduler=noise_scheduler, | 563 | noise_scheduler=noise_scheduler, |
559 | dtype=weight_dtype, | 564 | dtype=weight_dtype, |
560 | with_prior_preservation=args.num_class_images != 0, | 565 | guidance_scale=args.guidance_scale, |
561 | prior_loss_weight=args.prior_loss_weight, | 566 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
562 | no_val=args.valid_set_size == 0, | 567 | no_val=args.valid_set_size == 0, |
563 | ) | 568 | ) |
564 | 569 | ||
@@ -570,6 +575,7 @@ def main(): | |||
570 | batch_size=args.train_batch_size, | 575 | batch_size=args.train_batch_size, |
571 | tokenizer=tokenizer, | 576 | tokenizer=tokenizer, |
572 | class_subdir=args.class_image_dir, | 577 | class_subdir=args.class_image_dir, |
578 | with_guidance=args.guidance_scale != 0, | ||
573 | num_class_images=args.num_class_images, | 579 | num_class_images=args.num_class_images, |
574 | size=args.resolution, | 580 | size=args.resolution, |
575 | num_buckets=args.num_buckets, | 581 | num_buckets=args.num_buckets, |
diff --git a/train_lora.py b/train_lora.py index b16a99b..684d0cc 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -88,7 +88,7 @@ def parse_args(): | |||
88 | parser.add_argument( | 88 | parser.add_argument( |
89 | "--num_buckets", | 89 | "--num_buckets", |
90 | type=int, | 90 | type=int, |
91 | default=0, | 91 | default=2, |
92 | help="Number of aspect ratio buckets in either direction.", | 92 | help="Number of aspect ratio buckets in either direction.", |
93 | ) | 93 | ) |
94 | parser.add_argument( | 94 | parser.add_argument( |
@@ -111,7 +111,7 @@ def parse_args(): | |||
111 | parser.add_argument( | 111 | parser.add_argument( |
112 | "--tag_dropout", | 112 | "--tag_dropout", |
113 | type=float, | 113 | type=float, |
114 | default=0.1, | 114 | default=0, |
115 | help="Tag dropout probability.", | 115 | help="Tag dropout probability.", |
116 | ) | 116 | ) |
117 | parser.add_argument( | 117 | parser.add_argument( |
@@ -120,6 +120,11 @@ def parse_args(): | |||
120 | help="Shuffle tags.", | 120 | help="Shuffle tags.", |
121 | ) | 121 | ) |
122 | parser.add_argument( | 122 | parser.add_argument( |
123 | "--guidance_scale", | ||
124 | type=float, | ||
125 | default=0, | ||
126 | ) | ||
127 | parser.add_argument( | ||
123 | "--num_class_images", | 128 | "--num_class_images", |
124 | type=int, | 129 | type=int, |
125 | default=0, | 130 | default=0, |
@@ -167,7 +172,7 @@ def parse_args(): | |||
167 | parser.add_argument( | 172 | parser.add_argument( |
168 | "--offset_noise_strength", | 173 | "--offset_noise_strength", |
169 | type=float, | 174 | type=float, |
170 | default=0.15, | 175 | default=0, |
171 | help="Perlin offset noise strength.", | 176 | help="Perlin offset noise strength.", |
172 | ) | 177 | ) |
173 | parser.add_argument( | 178 | parser.add_argument( |
@@ -589,8 +594,8 @@ def main(): | |||
589 | vae=vae, | 594 | vae=vae, |
590 | noise_scheduler=noise_scheduler, | 595 | noise_scheduler=noise_scheduler, |
591 | dtype=weight_dtype, | 596 | dtype=weight_dtype, |
592 | with_prior_preservation=args.num_class_images != 0, | 597 | guidance_scale=args.guidance_scale, |
593 | prior_loss_weight=args.prior_loss_weight, | 598 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
594 | no_val=args.valid_set_size == 0, | 599 | no_val=args.valid_set_size == 0, |
595 | ) | 600 | ) |
596 | 601 | ||
@@ -602,6 +607,7 @@ def main(): | |||
602 | batch_size=args.train_batch_size, | 607 | batch_size=args.train_batch_size, |
603 | tokenizer=tokenizer, | 608 | tokenizer=tokenizer, |
604 | class_subdir=args.class_image_dir, | 609 | class_subdir=args.class_image_dir, |
610 | with_guidance=args.guidance_scale != 0, | ||
605 | num_class_images=args.num_class_images, | 611 | num_class_images=args.num_class_images, |
606 | size=args.resolution, | 612 | size=args.resolution, |
607 | num_buckets=args.num_buckets, | 613 | num_buckets=args.num_buckets, |
diff --git a/train_ti.py b/train_ti.py index bbc5524..83ad46d 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -91,6 +91,11 @@ def parse_args(): | |||
91 | action="store_true", | 91 | action="store_true", |
92 | ) | 92 | ) |
93 | parser.add_argument( | 93 | parser.add_argument( |
94 | "--guidance_scale", | ||
95 | type=float, | ||
96 | default=0, | ||
97 | ) | ||
98 | parser.add_argument( | ||
94 | "--num_class_images", | 99 | "--num_class_images", |
95 | type=int, | 100 | type=int, |
96 | default=0, | 101 | default=0, |
@@ -167,7 +172,7 @@ def parse_args(): | |||
167 | parser.add_argument( | 172 | parser.add_argument( |
168 | "--tag_dropout", | 173 | "--tag_dropout", |
169 | type=float, | 174 | type=float, |
170 | default=0.1, | 175 | default=0, |
171 | help="Tag dropout probability.", | 176 | help="Tag dropout probability.", |
172 | ) | 177 | ) |
173 | parser.add_argument( | 178 | parser.add_argument( |
@@ -190,7 +195,7 @@ def parse_args(): | |||
190 | parser.add_argument( | 195 | parser.add_argument( |
191 | "--offset_noise_strength", | 196 | "--offset_noise_strength", |
192 | type=float, | 197 | type=float, |
193 | default=0.15, | 198 | default=0, |
194 | help="Perlin offset noise strength.", | 199 | help="Perlin offset noise strength.", |
195 | ) | 200 | ) |
196 | parser.add_argument( | 201 | parser.add_argument( |
@@ -651,8 +656,8 @@ def main(): | |||
651 | noise_scheduler=noise_scheduler, | 656 | noise_scheduler=noise_scheduler, |
652 | dtype=weight_dtype, | 657 | dtype=weight_dtype, |
653 | seed=args.seed, | 658 | seed=args.seed, |
654 | with_prior_preservation=args.num_class_images != 0, | 659 | guidance_scale=args.guidance_scale, |
655 | prior_loss_weight=args.prior_loss_weight, | 660 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
656 | no_val=args.valid_set_size == 0, | 661 | no_val=args.valid_set_size == 0, |
657 | strategy=textual_inversion_strategy, | 662 | strategy=textual_inversion_strategy, |
658 | num_train_epochs=args.num_train_epochs, | 663 | num_train_epochs=args.num_train_epochs, |
@@ -705,6 +710,7 @@ def main(): | |||
705 | batch_size=args.train_batch_size, | 710 | batch_size=args.train_batch_size, |
706 | tokenizer=tokenizer, | 711 | tokenizer=tokenizer, |
707 | class_subdir=args.class_image_dir, | 712 | class_subdir=args.class_image_dir, |
713 | with_guidance=args.guidance_scale != 0, | ||
708 | num_class_images=args.num_class_images, | 714 | num_class_images=args.num_class_images, |
709 | size=args.resolution, | 715 | size=args.resolution, |
710 | num_buckets=args.num_buckets, | 716 | num_buckets=args.num_buckets, |
diff --git a/training/functional.py b/training/functional.py index 87bb339..d285366 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -274,7 +274,7 @@ def loss_step( | |||
274 | noise_scheduler: SchedulerMixin, | 274 | noise_scheduler: SchedulerMixin, |
275 | unet: UNet2DConditionModel, | 275 | unet: UNet2DConditionModel, |
276 | text_encoder: CLIPTextModel, | 276 | text_encoder: CLIPTextModel, |
277 | with_prior_preservation: bool, | 277 | guidance_scale: float, |
278 | prior_loss_weight: float, | 278 | prior_loss_weight: float, |
279 | seed: int, | 279 | seed: int, |
280 | offset_noise_strength: float, | 280 | offset_noise_strength: float, |
@@ -283,13 +283,13 @@ def loss_step( | |||
283 | eval: bool = False, | 283 | eval: bool = False, |
284 | min_snr_gamma: int = 5, | 284 | min_snr_gamma: int = 5, |
285 | ): | 285 | ): |
286 | # Convert images to latent space | 286 | images = batch["pixel_values"] |
287 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 287 | generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None |
288 | latents = latents * vae.config.scaling_factor | 288 | bsz = images.shape[0] |
289 | |||
290 | bsz = latents.shape[0] | ||
291 | 289 | ||
292 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | 290 | # Convert images to latent space |
291 | latents = vae.encode(images).latent_dist.sample(generator=generator) | ||
292 | latents *= vae.config.scaling_factor | ||
293 | 293 | ||
294 | # Sample noise that we'll add to the latents | 294 | # Sample noise that we'll add to the latents |
295 | noise = torch.randn( | 295 | noise = torch.randn( |
@@ -301,13 +301,13 @@ def loss_step( | |||
301 | ) | 301 | ) |
302 | 302 | ||
303 | if offset_noise_strength != 0: | 303 | if offset_noise_strength != 0: |
304 | noise += offset_noise_strength * perlin_noise( | 304 | offset_noise = torch.randn( |
305 | latents.shape, | 305 | (latents.shape[0], latents.shape[1], 1, 1), |
306 | res=1, | ||
307 | dtype=latents.dtype, | 306 | dtype=latents.dtype, |
308 | device=latents.device, | 307 | device=latents.device, |
309 | generator=generator | 308 | generator=generator |
310 | ) | 309 | ).expand(noise.shape) |
310 | noise += offset_noise_strength * offset_noise | ||
311 | 311 | ||
312 | # Sample a random timestep for each image | 312 | # Sample a random timestep for each image |
313 | timesteps = torch.randint( | 313 | timesteps = torch.randint( |
@@ -343,7 +343,13 @@ def loss_step( | |||
343 | else: | 343 | else: |
344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
345 | 345 | ||
346 | if with_prior_preservation: | 346 | if guidance_scale != 0: |
347 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
348 | model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0) | ||
349 | model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) | ||
350 | |||
351 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | ||
352 | elif prior_loss_weight != 0: | ||
347 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 353 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
348 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 354 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
349 | target, target_prior = torch.chunk(target, 2, dim=0) | 355 | target, target_prior = torch.chunk(target, 2, dim=0) |
@@ -607,9 +613,9 @@ def train( | |||
607 | checkpoint_frequency: int = 50, | 613 | checkpoint_frequency: int = 50, |
608 | milestone_checkpoints: bool = True, | 614 | milestone_checkpoints: bool = True, |
609 | global_step_offset: int = 0, | 615 | global_step_offset: int = 0, |
610 | with_prior_preservation: bool = False, | 616 | guidance_scale: float = 0.0, |
611 | prior_loss_weight: float = 1.0, | 617 | prior_loss_weight: float = 1.0, |
612 | offset_noise_strength: float = 0.1, | 618 | offset_noise_strength: float = 0.15, |
613 | **kwargs, | 619 | **kwargs, |
614 | ): | 620 | ): |
615 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 621 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
@@ -638,7 +644,7 @@ def train( | |||
638 | noise_scheduler, | 644 | noise_scheduler, |
639 | unet, | 645 | unet, |
640 | text_encoder, | 646 | text_encoder, |
641 | with_prior_preservation, | 647 | guidance_scale, |
642 | prior_loss_weight, | 648 | prior_loss_weight, |
643 | seed, | 649 | seed, |
644 | offset_noise_strength, | 650 | offset_noise_strength, |