From 9a42def9fcfb9a5c5471d640253ed6c8f45c4973 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 30 Sep 2022 14:13:51 +0200 Subject: Added custom SD pipeline + euler_a scheduler --- .gitignore | 2 +- dreambooth.py | 12 +- infer.py | 111 +++-- .../clip_guided_stable_diffusion.py | 457 +++++++++++++++++++++ schedulers/scheduling_euler_a.py | 323 +++++++++++++++ 5 files changed, 869 insertions(+), 36 deletions(-) create mode 100644 pipelines/stable_diffusion/clip_guided_stable_diffusion.py create mode 100644 schedulers/scheduling_euler_a.py diff --git a/.gitignore b/.gitignore index 4456cef..35b4c22 100644 --- a/.gitignore +++ b/.gitignore @@ -160,5 +160,5 @@ cython_debug/ #.idea/ output/ -conf*.json +conf/ v1-inference.yaml* diff --git a/dreambooth.py b/dreambooth.py index 39c4851..4d7366c 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -59,7 +59,7 @@ def parse_args(): parser.add_argument( "--repeats", type=int, - default=100, + default=1, help="How many times to repeat the training data." ) parser.add_argument( @@ -375,7 +375,6 @@ class Checkpointer: @torch.no_grad() def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): samples_path = Path(self.output_dir).joinpath("samples") - samples_path.mkdir(parents=True, exist_ok=True) unwrapped = self.accelerator.unwrap_model(self.unet) pipeline = StableDiffusionPipeline( @@ -403,6 +402,7 @@ class Checkpointer: all_samples = [] file_path = samples_path.joinpath("stable", f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) data_enum = enumerate(val_data) @@ -436,6 +436,7 @@ class Checkpointer: for data, pool in [(val_data, "val"), (train_data, "train")]: all_samples = [] file_path = samples_path.joinpath(pool, f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) data_enum = enumerate(data) @@ -496,11 +497,15 @@ def main(): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32 + torch_dtype = torch.float32 + if accelerator.device.type == "cuda": + torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision] + pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype) pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(disable=True) + pipeline.to(accelerator.device) num_new_images = args.num_class_images - cur_class_images logger.info(f"Number of class images to sample: {num_new_images}.") @@ -509,7 +514,6 @@ def main(): sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) sample_dataloader = accelerator.prepare(sample_dataloader) - pipeline.to(accelerator.device) for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process diff --git a/infer.py b/infer.py index f2007e9..de3d792 100644 --- a/infer.py +++ b/infer.py @@ -1,18 +1,15 @@ import argparse import datetime +import logging from pathlib import Path from torch import autocast -from diffusers import StableDiffusionPipeline import torch import json -from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler -from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor +from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler +from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor from slugify import slugify -from pipelines.stable_diffusion.no_check import NoCheck - -model_id = "path-to-your-trained-model" - -prompt = "A photo of sks dog in a bucket" +from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion +from schedulers.scheduling_euler_a import EulerAScheduler def parse_args(): @@ -29,6 +26,21 @@ def parse_args(): type=str, default=None, ) + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + ) + parser.add_argument( + "--width", + type=int, + default=512, + ) + parser.add_argument( + "--height", + type=int, + default=512, + ) parser.add_argument( "--batch_size", type=int, @@ -42,17 +54,28 @@ def parse_args(): parser.add_argument( "--steps", type=int, - default=80, + default=120, ) parser.add_argument( - "--scale", + "--scheduler", + type=str, + choices=["plms", "ddim", "klms", "euler_a"], + default="euler_a", + ) + parser.add_argument( + "--guidance_scale", type=int, default=7.5, ) + parser.add_argument( + "--clip_guidance_scale", + type=int, + default=100, + ) parser.add_argument( "--seed", type=int, - default=None, + default=torch.random.seed(), ) parser.add_argument( "--output_dir", @@ -81,31 +104,39 @@ def save_args(basepath, args, extra={}): json.dump(info, f, indent=4) -def main(): - args = parse_args() - - seed = args.seed or torch.random.seed() - - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") - output_dir.mkdir(parents=True, exist_ok=True) - save_args(output_dir, args) - +def gen(args, output_dir): tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) + clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) - feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16) + feature_extractor = CLIPFeatureExtractor.from_pretrained( + "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) - pipeline = StableDiffusionPipeline( + if args.scheduler == "plms": + scheduler = PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) + elif args.scheduler == "klms": + scheduler = LMSDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + elif args.scheduler == "ddim": + scheduler = DDIMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False + ) + else: + scheduler = EulerAScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False + ) + + pipeline = CLIPGuidedStableDiffusion( text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ), - safety_checker=NoCheck(), + clip_model=clip_model, + scheduler=scheduler, feature_extractor=feature_extractor ) pipeline.enable_attention_slicing() @@ -113,16 +144,34 @@ def main(): with autocast("cuda"): for i in range(args.batch_num): - generator = torch.Generator(device="cuda").manual_seed(seed + i) + generator = torch.Generator(device="cuda").manual_seed(args.seed + i) images = pipeline( - [args.prompt] * args.batch_size, + prompt=[args.prompt] * args.batch_size, + height=args.height, + width=args.width, + negative_prompt=args.negative_prompt, num_inference_steps=args.steps, - guidance_scale=args.scale, + guidance_scale=args.guidance_scale, + clip_guidance_scale=args.clip_guidance_scale, generator=generator, ).images for j, image in enumerate(images): - image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) + image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) + + +def main(): + args = parse_args() + + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") + output_dir.mkdir(parents=True, exist_ok=True) + + save_args(output_dir, args) + + logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + + gen(args, output_dir) if __name__ == "__main__": diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py new file mode 100644 index 0000000..306d9a9 --- /dev/null +++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py @@ -0,0 +1,457 @@ +import inspect +import warnings +from typing import List, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers.configuration_utils import FrozenDict +from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import logging +from torchvision import transforms +from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer +from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MakeCutouts(nn.Module): + def __init__(self, cut_size, cut_power=1.0): + super().__init__() + + self.cut_size = cut_size + self.cut_power = cut_power + + def forward(self, pixel_values, num_cutouts): + sideY, sideX = pixel_values.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(num_cutouts): + size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size] + cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) + + +def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + + +def set_requires_grad(model, value): + for param in model.parameters(): + param.requires_grad = value + + +class CLIPGuidedStableDiffusion(DiffusionPipeline): + """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000 + - https://github.com/Jack000/glid-3-xl + - https://github.dev/crowsonkb/k-diffusion + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + clip_model: CLIPModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + feature_extractor: CLIPFeatureExtractor, + **kwargs, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + clip_model=clip_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + + self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) + self.make_cutouts = MakeCutouts(feature_extractor.size) + + set_requires_grad(self.text_encoder, False) + set_requires_grad(self.clip_model, False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def freeze_vae(self): + set_requires_grad(self.vae, False) + + def unfreeze_vae(self): + set_requires_grad(self.vae, True) + + def freeze_unet(self): + set_requires_grad(self.unet, False) + + def unfreeze_unet(self): + set_requires_grad(self.unet, True) + + @torch.enable_grad() + def cond_fn( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + text_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + latents = latents.detach().requires_grad_() + + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + + if isinstance(self.scheduler, PNDMScheduler): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + + if use_cutouts: + image = self.make_cutouts(image, num_cutouts) + else: + image = transforms.Resize(self.feature_extractor.size)(image) + image = self.normalize(image) + + image_embeddings_clip = self.clip_model.get_image_features(image).float() + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + + if use_cutouts: + dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip) + dists = dists.view([num_cutouts, sample.shape[0], -1]) + loss = dists.sum(2).mean(0).sum() * clip_guidance_scale + else: + loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale + + grads = -torch.autograd.grad(loss, latents)[0] + + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + clip_guidance_scale: Optional[float] = 100, + clip_prompt: Optional[Union[str, List[str]]] = None, + num_cutouts: Optional[int] = 4, + use_cutouts: Optional[bool] = True, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + elif isinstance(negative_prompt, list): + if len(negative_prompt) != batch_size: + raise ValueError( + f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") + else: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + if clip_guidance_scale > 0: + if clip_prompt is not None: + clip_text_inputs = self.tokenizer( + clip_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + clip_text_input_ids = clip_text_inputs.input_ids + else: + clip_text_input_ids = text_input_ids + text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device)) + text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_device = "cpu" if self.device.type == "mps" else self.device + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=latents_device, + dtype=text_embeddings.dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimzed to move all timesteps to correct device beforehand + if torch.is_tensor(self.scheduler.timesteps): + timesteps_tensor = self.scheduler.timesteps.to(self.device) + else: + timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device) + + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + elif isinstance(self.scheduler, EulerAScheduler): + sigma = self.scheduler.timesteps[0] + latents = latents * sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in scheduler_step_args + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + accepts_generator = "generator" in scheduler_step_args + if generator is not None and accepts_generator: + extra_step_kwargs["generator"] = generator + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + noise_pred = None + if isinstance(self.scheduler, EulerAScheduler): + sigma = t.reshape(1) + sigma_in = torch.cat([sigma] * 2) + # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) + noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, + text_embeddings, guidance_scale, DSsigmas=self.scheduler.DSsigmas) + # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample + else: + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # perform clip guidance + if clip_guidance_scale > 0: + text_embeddings_for_guidance = ( + text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings + ) + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + text_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts, + ) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + elif isinstance(self.scheduler, EulerAScheduler): + if i < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error + t_prev = self.scheduler.timesteps[i+1] + latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py new file mode 100644 index 0000000..57a56de --- /dev/null +++ b/schedulers/scheduling_euler_a.py @@ -0,0 +1,323 @@ + + +import math +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput + + +''' +helper functions: append_zero(), + t_to_sigma(), + get_sigmas(), + append_dims(), + CFGDenoiserForward(), + get_scalings(), + DSsigma_to_t(), + DiscreteEpsDDPMDenoiserForward(), + to_d(), + get_ancestral_step() +need cleaning +''' + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def t_to_sigma(t, sigmas): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] + + +def get_sigmas(sigmas, n=None): + if n is None: + return append_zero(sigmas.flip(0)) + t_max = len(sigmas) - 1 # = 999 + t = torch.linspace(t_max, 0, n, device=sigmas.device) + # t = torch.linspace(t_max, 0, n, device=sigmas.device) + return append_zero(t_to_sigma(t, sigmas)) + +# from k_samplers utils.py + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): + # x_in = torch.cat([x] * 2)#A# concat the latent + # sigma_in = torch.cat([sigma] * 2) #A# concat sigma + # cond_in = torch.cat([uncond, cond]) + # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) + # return uncond + (cond - uncond) * cond_scale + noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) + return noise_pred + +# from k_samplers sampling.py + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim) + + +def get_scalings(sigma): + sigma_data = 1. + c_out = -sigma + c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 + return c_out, c_in + +# DiscreteSchedule DS + + +def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): + # quantize = self.quantize if quantize is None else quantize + quantize = False + dists = torch.abs(sigma - DSsigmas[:, None]) + if quantize: + return torch.argmin(dists, dim=0).view(sigma.shape) + low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] + low, high = DSsigmas[low_idx], DSsigmas[high_idx] + w = (low - sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) + + +def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): + sigma = sigma.to(Unet.device) + DSsigmas = DSsigmas.to(Unet.device) + c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] + # ??? what is eps? + # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs) + eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas), + encoder_hidden_states=kwargs['cond']).sample + return input + eps * c_out + + +# from k_samplers sampling.py +def get_ancestral_step(sigma_from, sigma_to): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +''' +Euler Ancestral Scheduler +''' + + +class EulerAScheduler(SchedulerMixin, ConfigMixin): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of + Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the + optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + if beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1] + + # A# take number of steps as input + # A# store 1) number of steps 2) timesteps 3) schedule + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + # offset = self.config.steps_offset + + # if "offset" in kwargs: + # warnings.warn( + # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." + # " Please pass `steps_offset` to `__init__` instead.", + # DeprecationWarning, + # ) + + # offset = kwargs["offset"] + + self.num_inference_steps = num_inference_steps + self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) + self.timesteps = self.sigmas + + def add_noise_to_input( + self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None + ) -> Tuple[torch.FloatTensor, float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + + TODO Args: + """ + if self.config.s_min <= sigma <= self.config.s_max: + gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.IntTensor, + timestep_prev: torch.IntTensor, + sample: torch.FloatTensor, + generator: None, + # ,sigma_hat: float, + # sigma_prev: float, + # sample_hat: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check). + Returns: + [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`: + [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + latents = sample + sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) + + # if callback is not None: + # callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output}) + d = to_d(latents, timestep, model_output) + # Euler method + dt = sigma_down - timestep + latents = latents + d * dt + latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, + generator=generator) * sigma_up + return SchedulerOutput(prev_sample=latents) + + def step_correct( + self, + model_output: torch.FloatTensor, + sigma_hat: float, + sigma_prev: float, + sample_hat: torch.FloatTensor, + sample_prev: torch.FloatTensor, + derivative: torch.FloatTensor, + generator: None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. TODO complete description + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor`): TODO + sample_prev (`torch.FloatTensor`): TODO + derivative (`torch.FloatTensor`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO + + """ + pred_original_sample = sample_prev + sigma_prev * model_output + derivative_corr = (sample_prev - pred_original_sample) / sigma_prev + sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) + + if not return_dict: + return (sample_prev, derivative) + + return SchedulerOutput(prev_sample=sample_prev) + + def add_noise(self, original_samples, noise, timesteps): + raise NotImplementedError() -- cgit v1.2.3-54-g00ecf