diff options
| -rw-r--r-- | data/csv.py | 21 | ||||
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 61 | ||||
| -rw-r--r-- | scripts/convert_diffusers_to_original_stable_diffusion.py | 234 | ||||
| -rw-r--r-- | scripts/convert_original_stable_diffusion_to_diffusers.py | 690 | ||||
| -rw-r--r-- | train_dreambooth.py | 14 | ||||
| -rw-r--r-- | train_lora.py | 16 | ||||
| -rw-r--r-- | train_ti.py | 14 | ||||
| -rw-r--r-- | training/functional.py | 36 |
8 files changed, 99 insertions, 987 deletions
diff --git a/data/csv.py b/data/csv.py index fba5d4b..a6cd065 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -99,14 +99,16 @@ def generate_buckets( | |||
| 99 | return buckets, bucket_items, bucket_assignments | 99 | return buckets, bucket_items, bucket_assignments |
| 100 | 100 | ||
| 101 | 101 | ||
| 102 | def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples): | 102 | def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples): |
| 103 | prompt_ids = [example["prompt_ids"] for example in examples] | 103 | prompt_ids = [example["prompt_ids"] for example in examples] |
| 104 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 104 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
| 105 | 105 | ||
| 106 | input_ids = [example["instance_prompt_ids"] for example in examples] | 106 | input_ids = [example["instance_prompt_ids"] for example in examples] |
| 107 | pixel_values = [example["instance_images"] for example in examples] | 107 | pixel_values = [example["instance_images"] for example in examples] |
| 108 | 108 | ||
| 109 | if with_prior_preservation: | 109 | if with_guidance: |
| 110 | input_ids += [example["negative_prompt_ids"] for example in examples] | ||
| 111 | elif with_prior_preservation: | ||
| 110 | input_ids += [example["class_prompt_ids"] for example in examples] | 112 | input_ids += [example["class_prompt_ids"] for example in examples] |
| 111 | pixel_values += [example["class_images"] for example in examples] | 113 | pixel_values += [example["class_images"] for example in examples] |
| 112 | 114 | ||
| @@ -133,7 +135,7 @@ class VlpnDataItem(NamedTuple): | |||
| 133 | class_image_path: Path | 135 | class_image_path: Path |
| 134 | prompt: list[str] | 136 | prompt: list[str] |
| 135 | cprompt: str | 137 | cprompt: str |
| 136 | nprompt: str | 138 | nprompt: list[str] |
| 137 | collection: list[str] | 139 | collection: list[str] |
| 138 | 140 | ||
| 139 | 141 | ||
| @@ -163,6 +165,7 @@ class VlpnDataModule(): | |||
| 163 | data_file: str, | 165 | data_file: str, |
| 164 | tokenizer: CLIPTokenizer, | 166 | tokenizer: CLIPTokenizer, |
| 165 | class_subdir: str = "cls", | 167 | class_subdir: str = "cls", |
| 168 | with_guidance: bool = False, | ||
| 166 | num_class_images: int = 1, | 169 | num_class_images: int = 1, |
| 167 | size: int = 768, | 170 | size: int = 768, |
| 168 | num_buckets: int = 0, | 171 | num_buckets: int = 0, |
| @@ -191,6 +194,7 @@ class VlpnDataModule(): | |||
| 191 | self.class_root = self.data_root / class_subdir | 194 | self.class_root = self.data_root / class_subdir |
| 192 | self.class_root.mkdir(parents=True, exist_ok=True) | 195 | self.class_root.mkdir(parents=True, exist_ok=True) |
| 193 | self.num_class_images = num_class_images | 196 | self.num_class_images = num_class_images |
| 197 | self.with_guidance = with_guidance | ||
| 194 | 198 | ||
| 195 | self.tokenizer = tokenizer | 199 | self.tokenizer = tokenizer |
| 196 | self.size = size | 200 | self.size = size |
| @@ -228,10 +232,10 @@ class VlpnDataModule(): | |||
| 228 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 232 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
| 229 | expansions | 233 | expansions |
| 230 | )), | 234 | )), |
| 231 | keywords_to_prompt(prompt_to_keywords( | 235 | prompt_to_keywords( |
| 232 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 236 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 233 | expansions | 237 | expansions |
| 234 | )), | 238 | ), |
| 235 | item["collection"].split(", ") if "collection" in item else [] | 239 | item["collection"].split(", ") if "collection" in item else [] |
| 236 | ) | 240 | ) |
| 237 | for item in data | 241 | for item in data |
| @@ -279,7 +283,7 @@ class VlpnDataModule(): | |||
| 279 | if self.seed is not None: | 283 | if self.seed is not None: |
| 280 | generator = generator.manual_seed(self.seed) | 284 | generator = generator.manual_seed(self.seed) |
| 281 | 285 | ||
| 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 286 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) |
| 283 | 287 | ||
| 284 | if valid_set_size == 0: | 288 | if valid_set_size == 0: |
| 285 | data_train, data_val = items, items | 289 | data_train, data_val = items, items |
| @@ -443,11 +447,14 @@ class VlpnDataset(IterableDataset): | |||
| 443 | example = {} | 447 | example = {} |
| 444 | 448 | ||
| 445 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) | 449 | example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt)) |
| 446 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 450 | example["nprompt_ids"] = self.get_input_ids(keywords_to_prompt(item.nprompt)) |
| 447 | 451 | ||
| 448 | example["instance_prompt_ids"] = self.get_input_ids( | 452 | example["instance_prompt_ids"] = self.get_input_ids( |
| 449 | keywords_to_prompt(item.prompt, self.dropout, True) | 453 | keywords_to_prompt(item.prompt, self.dropout, True) |
| 450 | ) | 454 | ) |
| 455 | example["negative_prompt_ids"] = self.get_input_ids( | ||
| 456 | keywords_to_prompt(item.nprompt, self.dropout, True) | ||
| 457 | ) | ||
| 451 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 458 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
| 452 | 459 | ||
| 453 | if self.num_class_images != 0: | 460 | if self.num_class_images != 0: |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index ea2a656..127ca50 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -307,39 +307,45 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 307 | 307 | ||
| 308 | return timesteps, num_inference_steps - t_start | 308 | return timesteps, num_inference_steps - t_start |
| 309 | 309 | ||
| 310 | def prepare_image(self, batch_size, width, height, dtype, device, generator=None): | 310 | def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None): |
| 311 | return (1.4 * perlin_noise( | 311 | offset_image = perlin_noise( |
| 312 | (batch_size, 1, width, height), | 312 | (batch_size, 1, width, height), |
| 313 | res=1, | 313 | res=1, |
| 314 | octaves=4, | ||
| 315 | generator=generator, | 314 | generator=generator, |
| 316 | dtype=dtype, | 315 | dtype=dtype, |
| 317 | device=device | 316 | device=device |
| 318 | )).clamp(-1, 1).expand(batch_size, 3, width, height) | 317 | ) |
| 318 | offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator) | ||
| 319 | offset_latents = self.vae.config.scaling_factor * offset_latents | ||
| 320 | return offset_latents | ||
| 319 | 321 | ||
| 320 | def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): | 322 | def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None): |
| 321 | init_image = init_image.to(device=device, dtype=dtype) | 323 | init_image = init_image.to(device=device, dtype=dtype) |
| 322 | init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) | 324 | latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) |
| 323 | init_latents = self.vae.config.scaling_factor * init_latents | 325 | latents = self.vae.config.scaling_factor * latents |
| 324 | 326 | ||
| 325 | if batch_size % init_latents.shape[0] != 0: | 327 | if batch_size % latents.shape[0] != 0: |
| 326 | raise ValueError( | 328 | raise ValueError( |
| 327 | f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." | 329 | f"Cannot duplicate `init_image` of batch size {latents.shape[0]} to {batch_size} text prompts." |
| 328 | ) | 330 | ) |
| 329 | else: | 331 | else: |
| 330 | batch_multiplier = batch_size // init_latents.shape[0] | 332 | batch_multiplier = batch_size // latents.shape[0] |
| 331 | init_latents = torch.cat([init_latents] * batch_multiplier, dim=0) | 333 | latents = torch.cat([latents] * batch_multiplier, dim=0) |
| 332 | 334 | ||
| 333 | # add noise to latents using the timesteps | 335 | # add noise to latents using the timesteps |
| 334 | noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) | 336 | noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) |
| 337 | |||
| 338 | if brightness_offset != 0: | ||
| 339 | noise += brightness_offset * self.prepare_brightness_offset( | ||
| 340 | batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator | ||
| 341 | ) | ||
| 335 | 342 | ||
| 336 | # get latents | 343 | # get latents |
| 337 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep) | 344 | latents = self.scheduler.add_noise(latents, noise, timestep) |
| 338 | latents = init_latents | ||
| 339 | 345 | ||
| 340 | return latents | 346 | return latents |
| 341 | 347 | ||
| 342 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): | 348 | def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None): |
| 343 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) | 349 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) |
| 344 | if isinstance(generator, list) and len(generator) != batch_size: | 350 | if isinstance(generator, list) and len(generator) != batch_size: |
| 345 | raise ValueError( | 351 | raise ValueError( |
| @@ -352,6 +358,11 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 352 | else: | 358 | else: |
| 353 | latents = latents.to(device) | 359 | latents = latents.to(device) |
| 354 | 360 | ||
| 361 | if brightness_offset != 0: | ||
| 362 | latents += brightness_offset * self.prepare_brightness_offset( | ||
| 363 | batch_size, height, width, dtype, device, generator | ||
| 364 | ) | ||
| 365 | |||
| 355 | # scale the initial noise by the standard deviation required by the scheduler | 366 | # scale the initial noise by the standard deviation required by the scheduler |
| 356 | latents = latents * self.scheduler.init_noise_sigma | 367 | latents = latents * self.scheduler.init_noise_sigma |
| 357 | return latents | 368 | return latents |
| @@ -395,7 +406,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 395 | sag_scale: float = 0.75, | 406 | sag_scale: float = 0.75, |
| 396 | eta: float = 0.0, | 407 | eta: float = 0.0, |
| 397 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 408 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| 398 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None, | 409 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
| 410 | brightness_offset: Union[float, torch.FloatTensor] = 0, | ||
| 399 | output_type: str = "pil", | 411 | output_type: str = "pil", |
| 400 | return_dict: bool = True, | 412 | return_dict: bool = True, |
| 401 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 413 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
| @@ -468,7 +480,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 468 | num_channels_latents = self.unet.in_channels | 480 | num_channels_latents = self.unet.in_channels |
| 469 | do_classifier_free_guidance = guidance_scale > 1.0 | 481 | do_classifier_free_guidance = guidance_scale > 1.0 |
| 470 | do_self_attention_guidance = sag_scale > 0.0 | 482 | do_self_attention_guidance = sag_scale > 0.0 |
| 471 | prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise" | 483 | prep_from_image = isinstance(image, PIL.Image.Image) |
| 472 | 484 | ||
| 473 | # 3. Encode input prompt | 485 | # 3. Encode input prompt |
| 474 | prompt_embeds = self.encode_prompt( | 486 | prompt_embeds = self.encode_prompt( |
| @@ -482,15 +494,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 482 | # 4. Prepare latent variables | 494 | # 4. Prepare latent variables |
| 483 | if isinstance(image, PIL.Image.Image): | 495 | if isinstance(image, PIL.Image.Image): |
| 484 | image = preprocess(image) | 496 | image = preprocess(image) |
| 485 | elif image == "noise": | ||
| 486 | image = self.prepare_image( | ||
| 487 | batch_size * num_images_per_prompt, | ||
| 488 | width, | ||
| 489 | height, | ||
| 490 | prompt_embeds.dtype, | ||
| 491 | device, | ||
| 492 | generator | ||
| 493 | ) | ||
| 494 | 497 | ||
| 495 | # 5. Prepare timesteps | 498 | # 5. Prepare timesteps |
| 496 | self.scheduler.set_timesteps(num_inference_steps, device=device) | 499 | self.scheduler.set_timesteps(num_inference_steps, device=device) |
| @@ -503,9 +506,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 503 | image, | 506 | image, |
| 504 | latent_timestep, | 507 | latent_timestep, |
| 505 | batch_size * num_images_per_prompt, | 508 | batch_size * num_images_per_prompt, |
| 509 | brightness_offset, | ||
| 506 | prompt_embeds.dtype, | 510 | prompt_embeds.dtype, |
| 507 | device, | 511 | device, |
| 508 | generator | 512 | generator, |
| 509 | ) | 513 | ) |
| 510 | else: | 514 | else: |
| 511 | latents = self.prepare_latents( | 515 | latents = self.prepare_latents( |
| @@ -513,10 +517,11 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 513 | num_channels_latents, | 517 | num_channels_latents, |
| 514 | height, | 518 | height, |
| 515 | width, | 519 | width, |
| 520 | brightness_offset, | ||
| 516 | prompt_embeds.dtype, | 521 | prompt_embeds.dtype, |
| 517 | device, | 522 | device, |
| 518 | generator, | 523 | generator, |
| 519 | image | 524 | image, |
| 520 | ) | 525 | ) |
| 521 | 526 | ||
| 522 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | 527 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline |
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py deleted file mode 100644 index 9888f62..0000000 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ /dev/null | |||
| @@ -1,234 +0,0 @@ | |||
| 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. | ||
| 2 | # *Only* converts the UNet, VAE, and Text Encoder. | ||
| 3 | # Does not convert optimizer state or any other thing. | ||
| 4 | |||
| 5 | import argparse | ||
| 6 | import os.path as osp | ||
| 7 | |||
| 8 | import torch | ||
| 9 | |||
| 10 | |||
| 11 | # =================# | ||
| 12 | # UNet Conversion # | ||
| 13 | # =================# | ||
| 14 | |||
| 15 | unet_conversion_map = [ | ||
| 16 | # (stable-diffusion, HF Diffusers) | ||
| 17 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), | ||
| 18 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), | ||
| 19 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), | ||
| 20 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), | ||
| 21 | ("input_blocks.0.0.weight", "conv_in.weight"), | ||
| 22 | ("input_blocks.0.0.bias", "conv_in.bias"), | ||
| 23 | ("out.0.weight", "conv_norm_out.weight"), | ||
| 24 | ("out.0.bias", "conv_norm_out.bias"), | ||
| 25 | ("out.2.weight", "conv_out.weight"), | ||
| 26 | ("out.2.bias", "conv_out.bias"), | ||
| 27 | ] | ||
| 28 | |||
| 29 | unet_conversion_map_resnet = [ | ||
| 30 | # (stable-diffusion, HF Diffusers) | ||
| 31 | ("in_layers.0", "norm1"), | ||
| 32 | ("in_layers.2", "conv1"), | ||
| 33 | ("out_layers.0", "norm2"), | ||
| 34 | ("out_layers.3", "conv2"), | ||
| 35 | ("emb_layers.1", "time_emb_proj"), | ||
| 36 | ("skip_connection", "conv_shortcut"), | ||
| 37 | ] | ||
| 38 | |||
| 39 | unet_conversion_map_layer = [] | ||
| 40 | # hardcoded number of downblocks and resnets/attentions... | ||
| 41 | # would need smarter logic for other networks. | ||
| 42 | for i in range(4): | ||
| 43 | # loop over downblocks/upblocks | ||
| 44 | |||
| 45 | for j in range(2): | ||
| 46 | # loop over resnets/attentions for downblocks | ||
| 47 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." | ||
| 48 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." | ||
| 49 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) | ||
| 50 | |||
| 51 | if i < 3: | ||
| 52 | # no attention layers in down_blocks.3 | ||
| 53 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." | ||
| 54 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." | ||
| 55 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) | ||
| 56 | |||
| 57 | for j in range(3): | ||
| 58 | # loop over resnets/attentions for upblocks | ||
| 59 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." | ||
| 60 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." | ||
| 61 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) | ||
| 62 | |||
| 63 | if i > 0: | ||
| 64 | # no attention layers in up_blocks.0 | ||
| 65 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." | ||
| 66 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." | ||
| 67 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) | ||
| 68 | |||
| 69 | if i < 3: | ||
| 70 | # no downsample in down_blocks.3 | ||
| 71 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." | ||
| 72 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." | ||
| 73 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) | ||
| 74 | |||
| 75 | # no upsample in up_blocks.3 | ||
| 76 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." | ||
| 77 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." | ||
| 78 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) | ||
| 79 | |||
| 80 | hf_mid_atn_prefix = "mid_block.attentions.0." | ||
| 81 | sd_mid_atn_prefix = "middle_block.1." | ||
| 82 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) | ||
| 83 | |||
| 84 | for j in range(2): | ||
| 85 | hf_mid_res_prefix = f"mid_block.resnets.{j}." | ||
| 86 | sd_mid_res_prefix = f"middle_block.{2*j}." | ||
| 87 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) | ||
| 88 | |||
| 89 | |||
| 90 | def convert_unet_state_dict(unet_state_dict): | ||
| 91 | # buyer beware: this is a *brittle* function, | ||
| 92 | # and correct output requires that all of these pieces interact in | ||
| 93 | # the exact order in which I have arranged them. | ||
| 94 | mapping = {k: k for k in unet_state_dict.keys()} | ||
| 95 | for sd_name, hf_name in unet_conversion_map: | ||
| 96 | mapping[hf_name] = sd_name | ||
| 97 | for k, v in mapping.items(): | ||
| 98 | if "resnets" in k: | ||
| 99 | for sd_part, hf_part in unet_conversion_map_resnet: | ||
| 100 | v = v.replace(hf_part, sd_part) | ||
| 101 | mapping[k] = v | ||
| 102 | for k, v in mapping.items(): | ||
| 103 | for sd_part, hf_part in unet_conversion_map_layer: | ||
| 104 | v = v.replace(hf_part, sd_part) | ||
| 105 | mapping[k] = v | ||
| 106 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} | ||
| 107 | return new_state_dict | ||
| 108 | |||
| 109 | |||
| 110 | # ================# | ||
| 111 | # VAE Conversion # | ||
| 112 | # ================# | ||
| 113 | |||
| 114 | vae_conversion_map = [ | ||
| 115 | # (stable-diffusion, HF Diffusers) | ||
| 116 | ("nin_shortcut", "conv_shortcut"), | ||
| 117 | ("norm_out", "conv_norm_out"), | ||
| 118 | ("mid.attn_1.", "mid_block.attentions.0."), | ||
| 119 | ] | ||
| 120 | |||
| 121 | for i in range(4): | ||
| 122 | # down_blocks have two resnets | ||
| 123 | for j in range(2): | ||
| 124 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." | ||
| 125 | sd_down_prefix = f"encoder.down.{i}.block.{j}." | ||
| 126 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) | ||
| 127 | |||
| 128 | if i < 3: | ||
| 129 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." | ||
| 130 | sd_downsample_prefix = f"down.{i}.downsample." | ||
| 131 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) | ||
| 132 | |||
| 133 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." | ||
| 134 | sd_upsample_prefix = f"up.{3-i}.upsample." | ||
| 135 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) | ||
| 136 | |||
| 137 | # up_blocks have three resnets | ||
| 138 | # also, up blocks in hf are numbered in reverse from sd | ||
| 139 | for j in range(3): | ||
| 140 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." | ||
| 141 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." | ||
| 142 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) | ||
| 143 | |||
| 144 | # this part accounts for mid blocks in both the encoder and the decoder | ||
| 145 | for i in range(2): | ||
| 146 | hf_mid_res_prefix = f"mid_block.resnets.{i}." | ||
| 147 | sd_mid_res_prefix = f"mid.block_{i+1}." | ||
| 148 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) | ||
| 149 | |||
| 150 | |||
| 151 | vae_conversion_map_attn = [ | ||
| 152 | # (stable-diffusion, HF Diffusers) | ||
| 153 | ("norm.", "group_norm."), | ||
| 154 | ("q.", "query."), | ||
| 155 | ("k.", "key."), | ||
| 156 | ("v.", "value."), | ||
| 157 | ("proj_out.", "proj_attn."), | ||
| 158 | ] | ||
| 159 | |||
| 160 | |||
| 161 | def reshape_weight_for_sd(w): | ||
| 162 | # convert HF linear weights to SD conv2d weights | ||
| 163 | return w.reshape(*w.shape, 1, 1) | ||
| 164 | |||
| 165 | |||
| 166 | def convert_vae_state_dict(vae_state_dict): | ||
| 167 | mapping = {k: k for k in vae_state_dict.keys()} | ||
| 168 | for k, v in mapping.items(): | ||
| 169 | for sd_part, hf_part in vae_conversion_map: | ||
| 170 | v = v.replace(hf_part, sd_part) | ||
| 171 | mapping[k] = v | ||
| 172 | for k, v in mapping.items(): | ||
| 173 | if "attentions" in k: | ||
| 174 | for sd_part, hf_part in vae_conversion_map_attn: | ||
| 175 | v = v.replace(hf_part, sd_part) | ||
| 176 | mapping[k] = v | ||
| 177 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} | ||
| 178 | weights_to_convert = ["q", "k", "v", "proj_out"] | ||
| 179 | for k, v in new_state_dict.items(): | ||
| 180 | for weight_name in weights_to_convert: | ||
| 181 | if f"mid.attn_1.{weight_name}.weight" in k: | ||
| 182 | print(f"Reshaping {k} for SD format") | ||
| 183 | new_state_dict[k] = reshape_weight_for_sd(v) | ||
| 184 | return new_state_dict | ||
| 185 | |||
| 186 | |||
| 187 | # =========================# | ||
| 188 | # Text Encoder Conversion # | ||
| 189 | # =========================# | ||
| 190 | # pretty much a no-op | ||
| 191 | |||
| 192 | |||
| 193 | def convert_text_enc_state_dict(text_enc_dict): | ||
| 194 | return text_enc_dict | ||
| 195 | |||
| 196 | |||
| 197 | if __name__ == "__main__": | ||
| 198 | parser = argparse.ArgumentParser() | ||
| 199 | |||
| 200 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") | ||
| 201 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") | ||
| 202 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") | ||
| 203 | |||
| 204 | args = parser.parse_args() | ||
| 205 | |||
| 206 | assert args.model_path is not None, "Must provide a model path!" | ||
| 207 | |||
| 208 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" | ||
| 209 | |||
| 210 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") | ||
| 211 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") | ||
| 212 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") | ||
| 213 | |||
| 214 | # Convert the UNet model | ||
| 215 | unet_state_dict = torch.load(unet_path, map_location="cpu") | ||
| 216 | unet_state_dict = convert_unet_state_dict(unet_state_dict) | ||
| 217 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} | ||
| 218 | |||
| 219 | # Convert the VAE model | ||
| 220 | vae_state_dict = torch.load(vae_path, map_location="cpu") | ||
| 221 | vae_state_dict = convert_vae_state_dict(vae_state_dict) | ||
| 222 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} | ||
| 223 | |||
| 224 | # Convert the text encoder model | ||
| 225 | text_enc_dict = torch.load(text_enc_path, map_location="cpu") | ||
| 226 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) | ||
| 227 | text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} | ||
| 228 | |||
| 229 | # Put together new checkpoint | ||
| 230 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} | ||
| 231 | if args.half: | ||
| 232 | state_dict = {k: v.half() for k, v in state_dict.items()} | ||
| 233 | state_dict = {"state_dict": state_dict} | ||
| 234 | torch.save(state_dict, args.checkpoint_path) | ||
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py deleted file mode 100644 index ee7fc33..0000000 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ /dev/null | |||
| @@ -1,690 +0,0 @@ | |||
| 1 | # coding=utf-8 | ||
| 2 | # Copyright 2022 The HuggingFace Inc. team. | ||
| 3 | # | ||
| 4 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 5 | # you may not use this file except in compliance with the License. | ||
| 6 | # You may obtain a copy of the License at | ||
| 7 | # | ||
| 8 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
| 9 | # | ||
| 10 | # Unless required by applicable law or agreed to in writing, software | ||
| 11 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
| 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 13 | # See the License for the specific language governing permissions and | ||
| 14 | # limitations under the License. | ||
| 15 | """ Conversion script for the LDM checkpoints. """ | ||
| 16 | |||
| 17 | import argparse | ||
| 18 | import os | ||
| 19 | |||
| 20 | import torch | ||
| 21 | |||
| 22 | |||
| 23 | try: | ||
| 24 | from omegaconf import OmegaConf | ||
| 25 | except ImportError: | ||
| 26 | raise ImportError( | ||
| 27 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." | ||
| 28 | ) | ||
| 29 | |||
| 30 | from diffusers import ( | ||
| 31 | AutoencoderKL, | ||
| 32 | DDIMScheduler, | ||
| 33 | LDMTextToImagePipeline, | ||
| 34 | LMSDiscreteScheduler, | ||
| 35 | PNDMScheduler, | ||
| 36 | StableDiffusionPipeline, | ||
| 37 | UNet2DConditionModel, | ||
| 38 | ) | ||
| 39 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel | ||
| 40 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
| 41 | from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer | ||
| 42 | |||
| 43 | |||
| 44 | def shave_segments(path, n_shave_prefix_segments=1): | ||
| 45 | """ | ||
| 46 | Removes segments. Positive values shave the first segments, negative shave the last segments. | ||
| 47 | """ | ||
| 48 | if n_shave_prefix_segments >= 0: | ||
| 49 | return ".".join(path.split(".")[n_shave_prefix_segments:]) | ||
| 50 | else: | ||
| 51 | return ".".join(path.split(".")[:n_shave_prefix_segments]) | ||
| 52 | |||
| 53 | |||
| 54 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
| 55 | """ | ||
| 56 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
| 57 | """ | ||
| 58 | mapping = [] | ||
| 59 | for old_item in old_list: | ||
| 60 | new_item = old_item.replace("in_layers.0", "norm1") | ||
| 61 | new_item = new_item.replace("in_layers.2", "conv1") | ||
| 62 | |||
| 63 | new_item = new_item.replace("out_layers.0", "norm2") | ||
| 64 | new_item = new_item.replace("out_layers.3", "conv2") | ||
| 65 | |||
| 66 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") | ||
| 67 | new_item = new_item.replace("skip_connection", "conv_shortcut") | ||
| 68 | |||
| 69 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 70 | |||
| 71 | mapping.append({"old": old_item, "new": new_item}) | ||
| 72 | |||
| 73 | return mapping | ||
| 74 | |||
| 75 | |||
| 76 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
| 77 | """ | ||
| 78 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
| 79 | """ | ||
| 80 | mapping = [] | ||
| 81 | for old_item in old_list: | ||
| 82 | new_item = old_item | ||
| 83 | |||
| 84 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") | ||
| 85 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 86 | |||
| 87 | mapping.append({"old": old_item, "new": new_item}) | ||
| 88 | |||
| 89 | return mapping | ||
| 90 | |||
| 91 | |||
| 92 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): | ||
| 93 | """ | ||
| 94 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
| 95 | """ | ||
| 96 | mapping = [] | ||
| 97 | for old_item in old_list: | ||
| 98 | new_item = old_item | ||
| 99 | |||
| 100 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') | ||
| 101 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') | ||
| 102 | |||
| 103 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') | ||
| 104 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') | ||
| 105 | |||
| 106 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 107 | |||
| 108 | mapping.append({"old": old_item, "new": new_item}) | ||
| 109 | |||
| 110 | return mapping | ||
| 111 | |||
| 112 | |||
| 113 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): | ||
| 114 | """ | ||
| 115 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
| 116 | """ | ||
| 117 | mapping = [] | ||
| 118 | for old_item in old_list: | ||
| 119 | new_item = old_item | ||
| 120 | |||
| 121 | new_item = new_item.replace("norm.weight", "group_norm.weight") | ||
| 122 | new_item = new_item.replace("norm.bias", "group_norm.bias") | ||
| 123 | |||
| 124 | new_item = new_item.replace("q.weight", "query.weight") | ||
| 125 | new_item = new_item.replace("q.bias", "query.bias") | ||
| 126 | |||
| 127 | new_item = new_item.replace("k.weight", "key.weight") | ||
| 128 | new_item = new_item.replace("k.bias", "key.bias") | ||
| 129 | |||
| 130 | new_item = new_item.replace("v.weight", "value.weight") | ||
| 131 | new_item = new_item.replace("v.bias", "value.bias") | ||
| 132 | |||
| 133 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") | ||
| 134 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") | ||
| 135 | |||
| 136 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 137 | |||
| 138 | mapping.append({"old": old_item, "new": new_item}) | ||
| 139 | |||
| 140 | return mapping | ||
| 141 | |||
| 142 | |||
| 143 | def assign_to_checkpoint( | ||
| 144 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None | ||
| 145 | ): | ||
| 146 | """ | ||
| 147 | This does the final conversion step: take locally converted weights and apply a global renaming | ||
| 148 | to them. It splits attention layers, and takes into account additional replacements | ||
| 149 | that may arise. | ||
| 150 | |||
| 151 | Assigns the weights to the new checkpoint. | ||
| 152 | """ | ||
| 153 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." | ||
| 154 | |||
| 155 | # Splits the attention layers into three variables. | ||
| 156 | if attention_paths_to_split is not None: | ||
| 157 | for path, path_map in attention_paths_to_split.items(): | ||
| 158 | old_tensor = old_checkpoint[path] | ||
| 159 | channels = old_tensor.shape[0] // 3 | ||
| 160 | |||
| 161 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) | ||
| 162 | |||
| 163 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 | ||
| 164 | |||
| 165 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) | ||
| 166 | query, key, value = old_tensor.split(channels // num_heads, dim=1) | ||
| 167 | |||
| 168 | checkpoint[path_map["query"]] = query.reshape(target_shape) | ||
| 169 | checkpoint[path_map["key"]] = key.reshape(target_shape) | ||
| 170 | checkpoint[path_map["value"]] = value.reshape(target_shape) | ||
| 171 | |||
| 172 | for path in paths: | ||
| 173 | new_path = path["new"] | ||
| 174 | |||
| 175 | # These have already been assigned | ||
| 176 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: | ||
| 177 | continue | ||
| 178 | |||
| 179 | # Global renaming happens here | ||
| 180 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") | ||
| 181 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") | ||
| 182 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") | ||
| 183 | |||
| 184 | if additional_replacements is not None: | ||
| 185 | for replacement in additional_replacements: | ||
| 186 | new_path = new_path.replace(replacement["old"], replacement["new"]) | ||
| 187 | |||
| 188 | # proj_attn.weight has to be converted from conv 1D to linear | ||
| 189 | if "proj_attn.weight" in new_path: | ||
| 190 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] | ||
| 191 | else: | ||
| 192 | checkpoint[new_path] = old_checkpoint[path["old"]] | ||
| 193 | |||
| 194 | |||
| 195 | def conv_attn_to_linear(checkpoint): | ||
| 196 | keys = list(checkpoint.keys()) | ||
| 197 | attn_keys = ["query.weight", "key.weight", "value.weight"] | ||
| 198 | for key in keys: | ||
| 199 | if ".".join(key.split(".")[-2:]) in attn_keys: | ||
| 200 | if checkpoint[key].ndim > 2: | ||
| 201 | checkpoint[key] = checkpoint[key][:, :, 0, 0] | ||
| 202 | elif "proj_attn.weight" in key: | ||
| 203 | if checkpoint[key].ndim > 2: | ||
| 204 | checkpoint[key] = checkpoint[key][:, :, 0] | ||
| 205 | |||
| 206 | |||
| 207 | def create_unet_diffusers_config(original_config): | ||
| 208 | """ | ||
| 209 | Creates a config for the diffusers based on the config of the LDM model. | ||
| 210 | """ | ||
| 211 | unet_params = original_config.model.params.unet_config.params | ||
| 212 | |||
| 213 | block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] | ||
| 214 | |||
| 215 | down_block_types = [] | ||
| 216 | resolution = 1 | ||
| 217 | for i in range(len(block_out_channels)): | ||
| 218 | block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" | ||
| 219 | down_block_types.append(block_type) | ||
| 220 | if i != len(block_out_channels) - 1: | ||
| 221 | resolution *= 2 | ||
| 222 | |||
| 223 | up_block_types = [] | ||
| 224 | for i in range(len(block_out_channels)): | ||
| 225 | block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" | ||
| 226 | up_block_types.append(block_type) | ||
| 227 | resolution //= 2 | ||
| 228 | |||
| 229 | config = dict( | ||
| 230 | sample_size=unet_params.image_size, | ||
| 231 | in_channels=unet_params.in_channels, | ||
| 232 | out_channels=unet_params.out_channels, | ||
| 233 | down_block_types=tuple(down_block_types), | ||
| 234 | up_block_types=tuple(up_block_types), | ||
| 235 | block_out_channels=tuple(block_out_channels), | ||
| 236 | layers_per_block=unet_params.num_res_blocks, | ||
| 237 | cross_attention_dim=unet_params.context_dim, | ||
| 238 | attention_head_dim=unet_params.num_heads, | ||
| 239 | ) | ||
| 240 | |||
| 241 | return config | ||
| 242 | |||
| 243 | |||
| 244 | def create_vae_diffusers_config(original_config): | ||
| 245 | """ | ||
| 246 | Creates a config for the diffusers based on the config of the LDM model. | ||
| 247 | """ | ||
| 248 | vae_params = original_config.model.params.first_stage_config.params.ddconfig | ||
| 249 | _ = original_config.model.params.first_stage_config.params.embed_dim | ||
| 250 | |||
| 251 | block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] | ||
| 252 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) | ||
| 253 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) | ||
| 254 | |||
| 255 | config = dict( | ||
| 256 | sample_size=vae_params.resolution, | ||
| 257 | in_channels=vae_params.in_channels, | ||
| 258 | out_channels=vae_params.out_ch, | ||
| 259 | down_block_types=tuple(down_block_types), | ||
| 260 | up_block_types=tuple(up_block_types), | ||
| 261 | block_out_channels=tuple(block_out_channels), | ||
| 262 | latent_channels=vae_params.z_channels, | ||
| 263 | layers_per_block=vae_params.num_res_blocks, | ||
| 264 | ) | ||
| 265 | return config | ||
| 266 | |||
| 267 | |||
| 268 | def create_diffusers_schedular(original_config): | ||
| 269 | schedular = DDIMScheduler( | ||
| 270 | num_train_timesteps=original_config.model.params.timesteps, | ||
| 271 | beta_start=original_config.model.params.linear_start, | ||
| 272 | beta_end=original_config.model.params.linear_end, | ||
| 273 | beta_schedule="scaled_linear", | ||
| 274 | ) | ||
| 275 | return schedular | ||
| 276 | |||
| 277 | |||
| 278 | def create_ldm_bert_config(original_config): | ||
| 279 | bert_params = original_config.model.parms.cond_stage_config.params | ||
| 280 | config = LDMBertConfig( | ||
| 281 | d_model=bert_params.n_embed, | ||
| 282 | encoder_layers=bert_params.n_layer, | ||
| 283 | encoder_ffn_dim=bert_params.n_embed * 4, | ||
| 284 | ) | ||
| 285 | return config | ||
| 286 | |||
| 287 | |||
| 288 | def convert_ldm_unet_checkpoint(checkpoint, config): | ||
| 289 | """ | ||
| 290 | Takes a state dict and a config, and returns a converted checkpoint. | ||
| 291 | """ | ||
| 292 | |||
| 293 | # extract state_dict for UNet | ||
| 294 | unet_state_dict = {} | ||
| 295 | unet_key = "model.diffusion_model." | ||
| 296 | keys = list(checkpoint.keys()) | ||
| 297 | for key in keys: | ||
| 298 | if key.startswith(unet_key): | ||
| 299 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) | ||
| 300 | |||
| 301 | new_checkpoint = {} | ||
| 302 | |||
| 303 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] | ||
| 304 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] | ||
| 305 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] | ||
| 306 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] | ||
| 307 | |||
| 308 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] | ||
| 309 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] | ||
| 310 | |||
| 311 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] | ||
| 312 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] | ||
| 313 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] | ||
| 314 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] | ||
| 315 | |||
| 316 | # Retrieves the keys for the input blocks only | ||
| 317 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) | ||
| 318 | input_blocks = { | ||
| 319 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] | ||
| 320 | for layer_id in range(num_input_blocks) | ||
| 321 | } | ||
| 322 | |||
| 323 | # Retrieves the keys for the middle blocks only | ||
| 324 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) | ||
| 325 | middle_blocks = { | ||
| 326 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] | ||
| 327 | for layer_id in range(num_middle_blocks) | ||
| 328 | } | ||
| 329 | |||
| 330 | # Retrieves the keys for the output blocks only | ||
| 331 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) | ||
| 332 | output_blocks = { | ||
| 333 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] | ||
| 334 | for layer_id in range(num_output_blocks) | ||
| 335 | } | ||
| 336 | |||
| 337 | for i in range(1, num_input_blocks): | ||
| 338 | block_id = (i - 1) // (config["layers_per_block"] + 1) | ||
| 339 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) | ||
| 340 | |||
| 341 | resnets = [ | ||
| 342 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key | ||
| 343 | ] | ||
| 344 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] | ||
| 345 | |||
| 346 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: | ||
| 347 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( | ||
| 348 | f"input_blocks.{i}.0.op.weight" | ||
| 349 | ) | ||
| 350 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( | ||
| 351 | f"input_blocks.{i}.0.op.bias" | ||
| 352 | ) | ||
| 353 | |||
| 354 | paths = renew_resnet_paths(resnets) | ||
| 355 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
| 356 | assign_to_checkpoint( | ||
| 357 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 358 | ) | ||
| 359 | |||
| 360 | if len(attentions): | ||
| 361 | paths = renew_attention_paths(attentions) | ||
| 362 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} | ||
| 363 | assign_to_checkpoint( | ||
| 364 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 365 | ) | ||
| 366 | |||
| 367 | resnet_0 = middle_blocks[0] | ||
| 368 | attentions = middle_blocks[1] | ||
| 369 | resnet_1 = middle_blocks[2] | ||
| 370 | |||
| 371 | resnet_0_paths = renew_resnet_paths(resnet_0) | ||
| 372 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) | ||
| 373 | |||
| 374 | resnet_1_paths = renew_resnet_paths(resnet_1) | ||
| 375 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) | ||
| 376 | |||
| 377 | attentions_paths = renew_attention_paths(attentions) | ||
| 378 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} | ||
| 379 | assign_to_checkpoint( | ||
| 380 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 381 | ) | ||
| 382 | |||
| 383 | for i in range(num_output_blocks): | ||
| 384 | block_id = i // (config["layers_per_block"] + 1) | ||
| 385 | layer_in_block_id = i % (config["layers_per_block"] + 1) | ||
| 386 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] | ||
| 387 | output_block_list = {} | ||
| 388 | |||
| 389 | for layer in output_block_layers: | ||
| 390 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) | ||
| 391 | if layer_id in output_block_list: | ||
| 392 | output_block_list[layer_id].append(layer_name) | ||
| 393 | else: | ||
| 394 | output_block_list[layer_id] = [layer_name] | ||
| 395 | |||
| 396 | if len(output_block_list) > 1: | ||
| 397 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] | ||
| 398 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] | ||
| 399 | |||
| 400 | resnet_0_paths = renew_resnet_paths(resnets) | ||
| 401 | paths = renew_resnet_paths(resnets) | ||
| 402 | |||
| 403 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
| 404 | assign_to_checkpoint( | ||
| 405 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 406 | ) | ||
| 407 | |||
| 408 | if ["conv.weight", "conv.bias"] in output_block_list.values(): | ||
| 409 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) | ||
| 410 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ | ||
| 411 | f"output_blocks.{i}.{index}.conv.weight" | ||
| 412 | ] | ||
| 413 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ | ||
| 414 | f"output_blocks.{i}.{index}.conv.bias" | ||
| 415 | ] | ||
| 416 | |||
| 417 | # Clear attentions as they have been attributed above. | ||
| 418 | if len(attentions) == 2: | ||
| 419 | attentions = [] | ||
| 420 | |||
| 421 | if len(attentions): | ||
| 422 | paths = renew_attention_paths(attentions) | ||
| 423 | meta_path = { | ||
| 424 | "old": f"output_blocks.{i}.1", | ||
| 425 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", | ||
| 426 | } | ||
| 427 | assign_to_checkpoint( | ||
| 428 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 429 | ) | ||
| 430 | else: | ||
| 431 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) | ||
| 432 | for path in resnet_0_paths: | ||
| 433 | old_path = ".".join(["output_blocks", str(i), path["old"]]) | ||
| 434 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) | ||
| 435 | |||
| 436 | new_checkpoint[new_path] = unet_state_dict[old_path] | ||
| 437 | |||
| 438 | return new_checkpoint | ||
| 439 | |||
| 440 | |||
| 441 | def convert_ldm_vae_checkpoint(checkpoint, config): | ||
| 442 | # extract state dict for VAE | ||
| 443 | vae_state_dict = {} | ||
| 444 | vae_key = "first_stage_model." | ||
| 445 | keys = list(checkpoint.keys()) | ||
| 446 | for key in keys: | ||
| 447 | if key.startswith(vae_key): | ||
| 448 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) | ||
| 449 | |||
| 450 | new_checkpoint = {} | ||
| 451 | |||
| 452 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] | ||
| 453 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] | ||
| 454 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] | ||
| 455 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] | ||
| 456 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] | ||
| 457 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] | ||
| 458 | |||
| 459 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] | ||
| 460 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] | ||
| 461 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] | ||
| 462 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] | ||
| 463 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] | ||
| 464 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] | ||
| 465 | |||
| 466 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] | ||
| 467 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] | ||
| 468 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] | ||
| 469 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] | ||
| 470 | |||
| 471 | # Retrieves the keys for the encoder down blocks only | ||
| 472 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) | ||
| 473 | down_blocks = { | ||
| 474 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) | ||
| 475 | } | ||
| 476 | |||
| 477 | # Retrieves the keys for the decoder up blocks only | ||
| 478 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) | ||
| 479 | up_blocks = { | ||
| 480 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) | ||
| 481 | } | ||
| 482 | |||
| 483 | for i in range(num_down_blocks): | ||
| 484 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] | ||
| 485 | |||
| 486 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: | ||
| 487 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( | ||
| 488 | f"encoder.down.{i}.downsample.conv.weight" | ||
| 489 | ) | ||
| 490 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( | ||
| 491 | f"encoder.down.{i}.downsample.conv.bias" | ||
| 492 | ) | ||
| 493 | |||
| 494 | paths = renew_vae_resnet_paths(resnets) | ||
| 495 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} | ||
| 496 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 497 | |||
| 498 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] | ||
| 499 | num_mid_res_blocks = 2 | ||
| 500 | for i in range(1, num_mid_res_blocks + 1): | ||
| 501 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] | ||
| 502 | |||
| 503 | paths = renew_vae_resnet_paths(resnets) | ||
| 504 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
| 505 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 506 | |||
| 507 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] | ||
| 508 | paths = renew_vae_attention_paths(mid_attentions) | ||
| 509 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
| 510 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 511 | conv_attn_to_linear(new_checkpoint) | ||
| 512 | |||
| 513 | for i in range(num_up_blocks): | ||
| 514 | block_id = num_up_blocks - 1 - i | ||
| 515 | resnets = [ | ||
| 516 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key | ||
| 517 | ] | ||
| 518 | |||
| 519 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: | ||
| 520 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ | ||
| 521 | f"decoder.up.{block_id}.upsample.conv.weight" | ||
| 522 | ] | ||
| 523 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ | ||
| 524 | f"decoder.up.{block_id}.upsample.conv.bias" | ||
| 525 | ] | ||
| 526 | |||
| 527 | paths = renew_vae_resnet_paths(resnets) | ||
| 528 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} | ||
| 529 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 530 | |||
| 531 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] | ||
| 532 | num_mid_res_blocks = 2 | ||
| 533 | for i in range(1, num_mid_res_blocks + 1): | ||
| 534 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] | ||
| 535 | |||
| 536 | paths = renew_vae_resnet_paths(resnets) | ||
| 537 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
| 538 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 539 | |||
| 540 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] | ||
| 541 | paths = renew_vae_attention_paths(mid_attentions) | ||
| 542 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
| 543 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 544 | conv_attn_to_linear(new_checkpoint) | ||
| 545 | return new_checkpoint | ||
| 546 | |||
| 547 | |||
| 548 | def convert_ldm_bert_checkpoint(checkpoint, config): | ||
| 549 | def _copy_attn_layer(hf_attn_layer, pt_attn_layer): | ||
| 550 | hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight | ||
| 551 | hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight | ||
| 552 | hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight | ||
| 553 | |||
| 554 | hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight | ||
| 555 | hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias | ||
| 556 | |||
| 557 | def _copy_linear(hf_linear, pt_linear): | ||
| 558 | hf_linear.weight = pt_linear.weight | ||
| 559 | hf_linear.bias = pt_linear.bias | ||
| 560 | |||
| 561 | def _copy_layer(hf_layer, pt_layer): | ||
| 562 | # copy layer norms | ||
| 563 | _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) | ||
| 564 | _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) | ||
| 565 | |||
| 566 | # copy attn | ||
| 567 | _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) | ||
| 568 | |||
| 569 | # copy MLP | ||
| 570 | pt_mlp = pt_layer[1][1] | ||
| 571 | _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) | ||
| 572 | _copy_linear(hf_layer.fc2, pt_mlp.net[2]) | ||
| 573 | |||
| 574 | def _copy_layers(hf_layers, pt_layers): | ||
| 575 | for i, hf_layer in enumerate(hf_layers): | ||
| 576 | if i != 0: | ||
| 577 | i += i | ||
| 578 | pt_layer = pt_layers[i : i + 2] | ||
| 579 | _copy_layer(hf_layer, pt_layer) | ||
| 580 | |||
| 581 | hf_model = LDMBertModel(config).eval() | ||
| 582 | |||
| 583 | # copy embeds | ||
| 584 | hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight | ||
| 585 | hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight | ||
| 586 | |||
| 587 | # copy layer norm | ||
| 588 | _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) | ||
| 589 | |||
| 590 | # copy hidden layers | ||
| 591 | _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) | ||
| 592 | |||
| 593 | _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) | ||
| 594 | |||
| 595 | return hf_model | ||
| 596 | |||
| 597 | |||
| 598 | if __name__ == "__main__": | ||
| 599 | parser = argparse.ArgumentParser() | ||
| 600 | |||
| 601 | parser.add_argument( | ||
| 602 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." | ||
| 603 | ) | ||
| 604 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml | ||
| 605 | parser.add_argument( | ||
| 606 | "--original_config_file", | ||
| 607 | default=None, | ||
| 608 | type=str, | ||
| 609 | help="The YAML config file corresponding to the original architecture.", | ||
| 610 | ) | ||
| 611 | parser.add_argument( | ||
| 612 | "--scheduler_type", | ||
| 613 | default="pndm", | ||
| 614 | type=str, | ||
| 615 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']", | ||
| 616 | ) | ||
| 617 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") | ||
| 618 | |||
| 619 | args = parser.parse_args() | ||
| 620 | |||
| 621 | if args.original_config_file is None: | ||
| 622 | os.system( | ||
| 623 | "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" | ||
| 624 | ) | ||
| 625 | args.original_config_file = "./v1-inference.yaml" | ||
| 626 | |||
| 627 | original_config = OmegaConf.load(args.original_config_file) | ||
| 628 | checkpoint = torch.load(args.checkpoint_path)["state_dict"] | ||
| 629 | |||
| 630 | num_train_timesteps = original_config.model.params.timesteps | ||
| 631 | beta_start = original_config.model.params.linear_start | ||
| 632 | beta_end = original_config.model.params.linear_end | ||
| 633 | if args.scheduler_type == "pndm": | ||
| 634 | scheduler = PNDMScheduler( | ||
| 635 | beta_end=beta_end, | ||
| 636 | beta_schedule="scaled_linear", | ||
| 637 | beta_start=beta_start, | ||
| 638 | num_train_timesteps=num_train_timesteps, | ||
| 639 | skip_prk_steps=True, | ||
| 640 | ) | ||
| 641 | elif args.scheduler_type == "lms": | ||
| 642 | scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") | ||
| 643 | elif args.scheduler_type == "ddim": | ||
| 644 | scheduler = DDIMScheduler( | ||
| 645 | beta_start=beta_start, | ||
| 646 | beta_end=beta_end, | ||
| 647 | beta_schedule="scaled_linear", | ||
| 648 | clip_sample=False, | ||
| 649 | set_alpha_to_one=False, | ||
| 650 | ) | ||
| 651 | else: | ||
| 652 | raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") | ||
| 653 | |||
| 654 | # Convert the UNet2DConditionModel model. | ||
| 655 | unet_config = create_unet_diffusers_config(original_config) | ||
| 656 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) | ||
| 657 | |||
| 658 | unet = UNet2DConditionModel(**unet_config) | ||
| 659 | unet.load_state_dict(converted_unet_checkpoint) | ||
| 660 | |||
| 661 | # Convert the VAE model. | ||
| 662 | vae_config = create_vae_diffusers_config(original_config) | ||
| 663 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) | ||
| 664 | |||
| 665 | vae = AutoencoderKL(**vae_config) | ||
| 666 | vae.load_state_dict(converted_vae_checkpoint) | ||
| 667 | |||
| 668 | # Convert the text model. | ||
| 669 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] | ||
| 670 | if text_model_type == "FrozenCLIPEmbedder": | ||
| 671 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") | ||
| 672 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") | ||
| 673 | safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
| 674 | feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
| 675 | pipe = StableDiffusionPipeline( | ||
| 676 | vae=vae, | ||
| 677 | text_encoder=text_model, | ||
| 678 | tokenizer=tokenizer, | ||
| 679 | unet=unet, | ||
| 680 | scheduler=scheduler, | ||
| 681 | safety_checker=safety_checker, | ||
| 682 | feature_extractor=feature_extractor, | ||
| 683 | ) | ||
| 684 | else: | ||
| 685 | text_config = create_ldm_bert_config(original_config) | ||
| 686 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) | ||
| 687 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") | ||
| 688 | pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) | ||
| 689 | |||
| 690 | pipe.save_pretrained(args.dump_path) | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1b8a3d2..7a33bca 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -110,7 +110,7 @@ def parse_args(): | |||
| 110 | parser.add_argument( | 110 | parser.add_argument( |
| 111 | "--tag_dropout", | 111 | "--tag_dropout", |
| 112 | type=float, | 112 | type=float, |
| 113 | default=0.1, | 113 | default=0, |
| 114 | help="Tag dropout probability.", | 114 | help="Tag dropout probability.", |
| 115 | ) | 115 | ) |
| 116 | parser.add_argument( | 116 | parser.add_argument( |
| @@ -131,6 +131,11 @@ def parse_args(): | |||
| 131 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', | 131 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', |
| 132 | ) | 132 | ) |
| 133 | parser.add_argument( | 133 | parser.add_argument( |
| 134 | "--guidance_scale", | ||
| 135 | type=float, | ||
| 136 | default=0, | ||
| 137 | ) | ||
| 138 | parser.add_argument( | ||
| 134 | "--num_class_images", | 139 | "--num_class_images", |
| 135 | type=int, | 140 | type=int, |
| 136 | default=0, | 141 | default=0, |
| @@ -178,7 +183,7 @@ def parse_args(): | |||
| 178 | parser.add_argument( | 183 | parser.add_argument( |
| 179 | "--offset_noise_strength", | 184 | "--offset_noise_strength", |
| 180 | type=float, | 185 | type=float, |
| 181 | default=0.15, | 186 | default=0, |
| 182 | help="Perlin offset noise strength.", | 187 | help="Perlin offset noise strength.", |
| 183 | ) | 188 | ) |
| 184 | parser.add_argument( | 189 | parser.add_argument( |
| @@ -557,8 +562,8 @@ def main(): | |||
| 557 | vae=vae, | 562 | vae=vae, |
| 558 | noise_scheduler=noise_scheduler, | 563 | noise_scheduler=noise_scheduler, |
| 559 | dtype=weight_dtype, | 564 | dtype=weight_dtype, |
| 560 | with_prior_preservation=args.num_class_images != 0, | 565 | guidance_scale=args.guidance_scale, |
| 561 | prior_loss_weight=args.prior_loss_weight, | 566 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 562 | no_val=args.valid_set_size == 0, | 567 | no_val=args.valid_set_size == 0, |
| 563 | ) | 568 | ) |
| 564 | 569 | ||
| @@ -570,6 +575,7 @@ def main(): | |||
| 570 | batch_size=args.train_batch_size, | 575 | batch_size=args.train_batch_size, |
| 571 | tokenizer=tokenizer, | 576 | tokenizer=tokenizer, |
| 572 | class_subdir=args.class_image_dir, | 577 | class_subdir=args.class_image_dir, |
| 578 | with_guidance=args.guidance_scale != 0, | ||
| 573 | num_class_images=args.num_class_images, | 579 | num_class_images=args.num_class_images, |
| 574 | size=args.resolution, | 580 | size=args.resolution, |
| 575 | num_buckets=args.num_buckets, | 581 | num_buckets=args.num_buckets, |
diff --git a/train_lora.py b/train_lora.py index b16a99b..684d0cc 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -88,7 +88,7 @@ def parse_args(): | |||
| 88 | parser.add_argument( | 88 | parser.add_argument( |
| 89 | "--num_buckets", | 89 | "--num_buckets", |
| 90 | type=int, | 90 | type=int, |
| 91 | default=0, | 91 | default=2, |
| 92 | help="Number of aspect ratio buckets in either direction.", | 92 | help="Number of aspect ratio buckets in either direction.", |
| 93 | ) | 93 | ) |
| 94 | parser.add_argument( | 94 | parser.add_argument( |
| @@ -111,7 +111,7 @@ def parse_args(): | |||
| 111 | parser.add_argument( | 111 | parser.add_argument( |
| 112 | "--tag_dropout", | 112 | "--tag_dropout", |
| 113 | type=float, | 113 | type=float, |
| 114 | default=0.1, | 114 | default=0, |
| 115 | help="Tag dropout probability.", | 115 | help="Tag dropout probability.", |
| 116 | ) | 116 | ) |
| 117 | parser.add_argument( | 117 | parser.add_argument( |
| @@ -120,6 +120,11 @@ def parse_args(): | |||
| 120 | help="Shuffle tags.", | 120 | help="Shuffle tags.", |
| 121 | ) | 121 | ) |
| 122 | parser.add_argument( | 122 | parser.add_argument( |
| 123 | "--guidance_scale", | ||
| 124 | type=float, | ||
| 125 | default=0, | ||
| 126 | ) | ||
| 127 | parser.add_argument( | ||
| 123 | "--num_class_images", | 128 | "--num_class_images", |
| 124 | type=int, | 129 | type=int, |
| 125 | default=0, | 130 | default=0, |
| @@ -167,7 +172,7 @@ def parse_args(): | |||
| 167 | parser.add_argument( | 172 | parser.add_argument( |
| 168 | "--offset_noise_strength", | 173 | "--offset_noise_strength", |
| 169 | type=float, | 174 | type=float, |
| 170 | default=0.15, | 175 | default=0, |
| 171 | help="Perlin offset noise strength.", | 176 | help="Perlin offset noise strength.", |
| 172 | ) | 177 | ) |
| 173 | parser.add_argument( | 178 | parser.add_argument( |
| @@ -589,8 +594,8 @@ def main(): | |||
| 589 | vae=vae, | 594 | vae=vae, |
| 590 | noise_scheduler=noise_scheduler, | 595 | noise_scheduler=noise_scheduler, |
| 591 | dtype=weight_dtype, | 596 | dtype=weight_dtype, |
| 592 | with_prior_preservation=args.num_class_images != 0, | 597 | guidance_scale=args.guidance_scale, |
| 593 | prior_loss_weight=args.prior_loss_weight, | 598 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 594 | no_val=args.valid_set_size == 0, | 599 | no_val=args.valid_set_size == 0, |
| 595 | ) | 600 | ) |
| 596 | 601 | ||
| @@ -602,6 +607,7 @@ def main(): | |||
| 602 | batch_size=args.train_batch_size, | 607 | batch_size=args.train_batch_size, |
| 603 | tokenizer=tokenizer, | 608 | tokenizer=tokenizer, |
| 604 | class_subdir=args.class_image_dir, | 609 | class_subdir=args.class_image_dir, |
| 610 | with_guidance=args.guidance_scale != 0, | ||
| 605 | num_class_images=args.num_class_images, | 611 | num_class_images=args.num_class_images, |
| 606 | size=args.resolution, | 612 | size=args.resolution, |
| 607 | num_buckets=args.num_buckets, | 613 | num_buckets=args.num_buckets, |
diff --git a/train_ti.py b/train_ti.py index bbc5524..83ad46d 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -91,6 +91,11 @@ def parse_args(): | |||
| 91 | action="store_true", | 91 | action="store_true", |
| 92 | ) | 92 | ) |
| 93 | parser.add_argument( | 93 | parser.add_argument( |
| 94 | "--guidance_scale", | ||
| 95 | type=float, | ||
| 96 | default=0, | ||
| 97 | ) | ||
| 98 | parser.add_argument( | ||
| 94 | "--num_class_images", | 99 | "--num_class_images", |
| 95 | type=int, | 100 | type=int, |
| 96 | default=0, | 101 | default=0, |
| @@ -167,7 +172,7 @@ def parse_args(): | |||
| 167 | parser.add_argument( | 172 | parser.add_argument( |
| 168 | "--tag_dropout", | 173 | "--tag_dropout", |
| 169 | type=float, | 174 | type=float, |
| 170 | default=0.1, | 175 | default=0, |
| 171 | help="Tag dropout probability.", | 176 | help="Tag dropout probability.", |
| 172 | ) | 177 | ) |
| 173 | parser.add_argument( | 178 | parser.add_argument( |
| @@ -190,7 +195,7 @@ def parse_args(): | |||
| 190 | parser.add_argument( | 195 | parser.add_argument( |
| 191 | "--offset_noise_strength", | 196 | "--offset_noise_strength", |
| 192 | type=float, | 197 | type=float, |
| 193 | default=0.15, | 198 | default=0, |
| 194 | help="Perlin offset noise strength.", | 199 | help="Perlin offset noise strength.", |
| 195 | ) | 200 | ) |
| 196 | parser.add_argument( | 201 | parser.add_argument( |
| @@ -651,8 +656,8 @@ def main(): | |||
| 651 | noise_scheduler=noise_scheduler, | 656 | noise_scheduler=noise_scheduler, |
| 652 | dtype=weight_dtype, | 657 | dtype=weight_dtype, |
| 653 | seed=args.seed, | 658 | seed=args.seed, |
| 654 | with_prior_preservation=args.num_class_images != 0, | 659 | guidance_scale=args.guidance_scale, |
| 655 | prior_loss_weight=args.prior_loss_weight, | 660 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 656 | no_val=args.valid_set_size == 0, | 661 | no_val=args.valid_set_size == 0, |
| 657 | strategy=textual_inversion_strategy, | 662 | strategy=textual_inversion_strategy, |
| 658 | num_train_epochs=args.num_train_epochs, | 663 | num_train_epochs=args.num_train_epochs, |
| @@ -705,6 +710,7 @@ def main(): | |||
| 705 | batch_size=args.train_batch_size, | 710 | batch_size=args.train_batch_size, |
| 706 | tokenizer=tokenizer, | 711 | tokenizer=tokenizer, |
| 707 | class_subdir=args.class_image_dir, | 712 | class_subdir=args.class_image_dir, |
| 713 | with_guidance=args.guidance_scale != 0, | ||
| 708 | num_class_images=args.num_class_images, | 714 | num_class_images=args.num_class_images, |
| 709 | size=args.resolution, | 715 | size=args.resolution, |
| 710 | num_buckets=args.num_buckets, | 716 | num_buckets=args.num_buckets, |
diff --git a/training/functional.py b/training/functional.py index 87bb339..d285366 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -274,7 +274,7 @@ def loss_step( | |||
| 274 | noise_scheduler: SchedulerMixin, | 274 | noise_scheduler: SchedulerMixin, |
| 275 | unet: UNet2DConditionModel, | 275 | unet: UNet2DConditionModel, |
| 276 | text_encoder: CLIPTextModel, | 276 | text_encoder: CLIPTextModel, |
| 277 | with_prior_preservation: bool, | 277 | guidance_scale: float, |
| 278 | prior_loss_weight: float, | 278 | prior_loss_weight: float, |
| 279 | seed: int, | 279 | seed: int, |
| 280 | offset_noise_strength: float, | 280 | offset_noise_strength: float, |
| @@ -283,13 +283,13 @@ def loss_step( | |||
| 283 | eval: bool = False, | 283 | eval: bool = False, |
| 284 | min_snr_gamma: int = 5, | 284 | min_snr_gamma: int = 5, |
| 285 | ): | 285 | ): |
| 286 | # Convert images to latent space | 286 | images = batch["pixel_values"] |
| 287 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 287 | generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None |
| 288 | latents = latents * vae.config.scaling_factor | 288 | bsz = images.shape[0] |
| 289 | |||
| 290 | bsz = latents.shape[0] | ||
| 291 | 289 | ||
| 292 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | 290 | # Convert images to latent space |
| 291 | latents = vae.encode(images).latent_dist.sample(generator=generator) | ||
| 292 | latents *= vae.config.scaling_factor | ||
| 293 | 293 | ||
| 294 | # Sample noise that we'll add to the latents | 294 | # Sample noise that we'll add to the latents |
| 295 | noise = torch.randn( | 295 | noise = torch.randn( |
| @@ -301,13 +301,13 @@ def loss_step( | |||
| 301 | ) | 301 | ) |
| 302 | 302 | ||
| 303 | if offset_noise_strength != 0: | 303 | if offset_noise_strength != 0: |
| 304 | noise += offset_noise_strength * perlin_noise( | 304 | offset_noise = torch.randn( |
| 305 | latents.shape, | 305 | (latents.shape[0], latents.shape[1], 1, 1), |
| 306 | res=1, | ||
| 307 | dtype=latents.dtype, | 306 | dtype=latents.dtype, |
| 308 | device=latents.device, | 307 | device=latents.device, |
| 309 | generator=generator | 308 | generator=generator |
| 310 | ) | 309 | ).expand(noise.shape) |
| 310 | noise += offset_noise_strength * offset_noise | ||
| 311 | 311 | ||
| 312 | # Sample a random timestep for each image | 312 | # Sample a random timestep for each image |
| 313 | timesteps = torch.randint( | 313 | timesteps = torch.randint( |
| @@ -343,7 +343,13 @@ def loss_step( | |||
| 343 | else: | 343 | else: |
| 344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
| 345 | 345 | ||
| 346 | if with_prior_preservation: | 346 | if guidance_scale != 0: |
| 347 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
| 348 | model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0) | ||
| 349 | model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) | ||
| 350 | |||
| 351 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | ||
| 352 | elif prior_loss_weight != 0: | ||
| 347 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 353 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 348 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 354 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| 349 | target, target_prior = torch.chunk(target, 2, dim=0) | 355 | target, target_prior = torch.chunk(target, 2, dim=0) |
| @@ -607,9 +613,9 @@ def train( | |||
| 607 | checkpoint_frequency: int = 50, | 613 | checkpoint_frequency: int = 50, |
| 608 | milestone_checkpoints: bool = True, | 614 | milestone_checkpoints: bool = True, |
| 609 | global_step_offset: int = 0, | 615 | global_step_offset: int = 0, |
| 610 | with_prior_preservation: bool = False, | 616 | guidance_scale: float = 0.0, |
| 611 | prior_loss_weight: float = 1.0, | 617 | prior_loss_weight: float = 1.0, |
| 612 | offset_noise_strength: float = 0.1, | 618 | offset_noise_strength: float = 0.15, |
| 613 | **kwargs, | 619 | **kwargs, |
| 614 | ): | 620 | ): |
| 615 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 621 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
| @@ -638,7 +644,7 @@ def train( | |||
| 638 | noise_scheduler, | 644 | noise_scheduler, |
| 639 | unet, | 645 | unet, |
| 640 | text_encoder, | 646 | text_encoder, |
| 641 | with_prior_preservation, | 647 | guidance_scale, |
| 642 | prior_loss_weight, | 648 | prior_loss_weight, |
| 643 | seed, | 649 | seed, |
| 644 | offset_noise_strength, | 650 | offset_noise_strength, |
