 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 169
 train_dreambooth.py                                 |   2
 train_lora.py                                       |   8
 train_ti.py                                         |   2
 4 files changed, 164 insertions(+), 17 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index dab7878..66566b0 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,9 +1,11 @@
 import inspect
 import warnings
+import math
 from typing import List, Dict, Any, Optional, Union, Callable
 
 import numpy as np
 import torch
+import torchvision.transforms as T
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
@@ -37,6 +39,35 @@ def preprocess(image):
     return 2.0 * image - 1.0
 
 
+class CrossAttnStoreProcessor:
+    def __init__(self):
+        self.attention_probs = None
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        self.attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(self.attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
@@ -233,9 +264,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             attention_mask = None
 
-        text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
+        prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
 
-        return text_embeddings
+        return prompt_embeds
 
     def get_timesteps(self, latents_are_image, num_inference_steps, strength, device):
         if latents_are_image:
@@ -330,6 +361,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
@@ -403,10 +435,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         batch_size = len(prompt)
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
+        do_self_attention_guidance = sag_scale > 0.0
         latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
-        text_embeddings = self.encode_prompt(
+        prompt_embeds = self.encode_prompt(
             prompt,
             negative_prompt,
             num_images_per_prompt,
@@ -427,7 +460,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 image,
                 latent_timestep,
                 batch_size * num_images_per_prompt,
-                text_embeddings.dtype,
+                prompt_embeds.dtype,
                 device,
                 generator
             )
@@ -437,7 +470,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 num_channels_latents,
                 height,
                 width,
-                text_embeddings.dtype,
+                prompt_embeds.dtype,
                 device,
                 generator,
                 image,
@@ -446,7 +479,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 7. Denoising loop
+        # 7. Denoising loo
+        if do_self_attention_guidance:
+            store_processor = CrossAttnStoreProcessor()
+            self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
+
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -458,7 +495,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
-                    encoder_hidden_states=text_embeddings,
+                    encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
                 ).sample
 
@@ -467,6 +504,36 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+                if do_self_attention_guidance:
+                    # classifier-free guidance produces two chunks of attention map
+                    # and we only use unconditional one according to equation (24)
+                    # in https://arxiv.org/pdf/2210.00939.pdf
+                    if do_classifier_free_guidance:
+                        # DDIM-like prediction of x0
+                        pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t)
+                        # get the stored attention maps
+                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                        # self-attention-based degrading of latents
+                        degraded_latents = self.sag_masking(
+                            pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t)
+                        )
+                        uncond_emb, _ = prompt_embeds.chunk(2)
+                        # forward and give guidance
+                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                        noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+                    else:
+                        # DDIM-like prediction of x0
+                        pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t)
+                        # get the stored attention maps
+                        cond_attn = store_processor.attention_probs
+                        # self-attention-based degrading of latents
+                        degraded_latents = self.sag_masking(
+                            pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t)
+                        )
+                        # forward and give guidance
+                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                        noise_pred += sag_scale * (noise_pred - degraded_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
@@ -490,3 +557,91 @@ class VlpnStableDiffusion(DiffusionPipeline):
             return (image, has_nsfw_concept)
 
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    # Self-Attention-Guided (SAG) Stable Diffusion
+
+    def sag_masking(self, original_latents, attn_map, t, eps):
+        # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
+        bh, hw1, hw2 = attn_map.shape
+        b, latent_channel, latent_h, latent_w = original_latents.shape
+        h = self.unet.attention_head_dim
+        if isinstance(h, list):
+            h = h[-1]
+        map_size = math.isqrt(hw1)
+
+        # Produce attention mask
+        attn_map = attn_map.reshape(b, h, hw1, hw2)
+        attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
+        attn_mask = (
+            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+        )
+        attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
+
+        # Blur according to the self-attention mask
+        transform = T.GaussianBlur(kernel_size=9, sigma=1.0)
+        degraded_latents = transform(original_latents)
+        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
+
+        # Noise it again to match the noise level
+        degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t)
+
+        return degraded_latents
+
+    # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
+    def pred_x0_from_eps(self, sample, model_output, timestep):
+        # 1. get previous step value (=t-1)
+        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        # alpha_prod_t_prev = (
+        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+        # )
+
+        beta_prod_t = 1 - alpha_prod_t
+        # 3. compute predicted original sample from predicted noise also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        if self.scheduler.config.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        elif self.scheduler.config.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.scheduler.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            # predict V
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+                " or `v_prediction`"
+            )
+        # # 4. Clip "predicted x_0"
+        # if self.scheduler.config.clip_sample:
+        #     pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+        return pred_original_sample
+
+    def pred_eps_from_noise(self, sample, model_output, timestep):
+        # 1. get previous step value (=t-1)
+        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        # alpha_prod_t_prev = (
+        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+        # )
+
+        beta_prod_t = 1 - alpha_prod_t
+        # 3. compute predicted eps from model output
+        if self.scheduler.config.prediction_type == "epsilon":
+            pred_eps = model_output
+        elif self.scheduler.config.prediction_type == "sample":
+            pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
+        elif self.scheduler.config.prediction_type == "v_prediction":
+            pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
+                " or `v_prediction`"
+            )
+
+        return pred_eps
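
The SAG path added above is driven entirely by the new `sag_scale` argument to `__call__`: any value above 0.0 enables `do_self_attention_guidance`, and the extra `sag_scale * (noise_pred_uncond - degraded_pred)` term is folded into the usual classifier-free-guidance prediction. A minimal usage sketch, assuming the pipeline components are loaded the standard diffusers way (the checkpoint path and prompt below are placeholders, not part of the patch):

    import torch
    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

    # Placeholder checkpoint path; any Stable Diffusion checkpoint providing the
    # components this pipeline expects should load the same way.
    pipeline = VlpnStableDiffusion.from_pretrained("path/to/model", torch_dtype=torch.float16)
    pipeline = pipeline.to("cuda")

    output = pipeline(
        prompt=["a photo of an astronaut riding a horse"],
        num_inference_steps=50,
        guidance_scale=7.5,   # classifier-free guidance, as before
        sag_scale=0.75,       # new: self-attention guidance; 0.0 skips the SAG branch
    )
    image = output.images[0]
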
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a29c507..8ac70e8 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -304,7 +304,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
diff --git a/train_lora.py b/train_lora.py
index ab1753b..5fd05cc 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -178,11 +178,6 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
-        "--gradient_checkpointing",
-        action="store_true",
-        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
-    )
-    parser.add_argument(
         "--find_lr",
         action="store_true",
         help="Automatically find a learning rate (no training).",
@@ -429,9 +424,6 @@ def main():
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
 
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
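
With the `--gradient_checkpointing` flag and its branch removed, checkpointing is simply off for LoRA training here. Should memory pressure ever call for it again, the same diffusers method the deleted branch used can still be invoked directly; a minimal sketch (the checkpoint path is a placeholder):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained("path/to/model", subfolder="unet")
    unet.enable_gradient_checkpointing()  # trades a slower backward pass for lower memory use
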
diff --git a/train_ti.py b/train_ti.py
index 2840def..c79dfa2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -143,7 +143,7 @@ def parse_args():
     parser.add_argument(
         "--num_buckets",
         type=int,
-        default=0,
+        default=4,
         help="Number of aspect ratio buckets in either direction.",
    )
     parser.add_argument(
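
The new default turns aspect-ratio bucketing on out of the box for textual-inversion training. Purely as an illustration of what "buckets in either direction" implies, the sketch below derives 2 * num_buckets + 1 resolutions at roughly constant pixel area; it is a generic example, not this repository's bucketing code, and the base size and ratio step are invented values:

    # Illustrative only: one square bucket plus num_buckets landscape and
    # num_buckets portrait buckets around a base resolution.
    def make_buckets(base_size=768, num_buckets=4, ratio_step=0.25):
        buckets = [(base_size, base_size)]
        for i in range(1, num_buckets + 1):
            ratio = 1.0 + ratio_step * i
            wide = int(base_size * ratio ** 0.5) // 8 * 8
            tall = int(base_size / ratio ** 0.5) // 8 * 8
            buckets.append((wide, tall))  # landscape bucket
            buckets.append((tall, wide))  # portrait bucket
        return buckets

    print(make_buckets())  # 9 buckets with the new default of num_buckets=4
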
