From b2c3389e9c6375d9081625e75a99de98395f8e77 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 1 Nov 2022 16:19:01 +0100
Subject: Update

---
 data/csv.py                                        |  11 +-
 dreambooth.py                                      |   2 +-
 .../stable_diffusion/vlpn_stable_diffusion.py      |   5 +-
 schedulers/scheduling_euler_ancestral_discrete.py  | 162 ++++++++++++++-------
 4 files changed, 117 insertions(+), 63 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 6bd7f9b..793fbf8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -150,7 +150,6 @@ class CSVDataset(Dataset):
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
         self.image_cache = {}
-        self.input_id_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -185,15 +184,7 @@ class CSVDataset(Dataset):
         return image
 
     def get_input_ids(self, prompt, identifier):
-        prompt = prompt.format(identifier)
-
-        if prompt in self.input_id_cache:
-            return self.input_id_cache[prompt]
-
-        input_ids = self.prompt_processor.get_input_ids(prompt)
-        self.input_id_cache[prompt] = input_ids
-
-        return input_ids
+        return self.prompt_processor.get_input_ids(prompt.format(identifier))
 
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
diff --git a/dreambooth.py b/dreambooth.py
index 17107d0..c0caf03 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -210,7 +210,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=5 / 6
+        default=7 / 8
     )
     parser.add_argument(
         "--ema_max_decay",
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index fc12355..cd5ae7e 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -203,6 +203,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # However this currently doesn't work in `mps`.
         latents_dtype = text_embeddings.dtype
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+
         if latents is None:
             if self.device.type == "mps":
                 # randn does not exist on mps
@@ -264,7 +265,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -275,7 +276,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
index 828e0dd..cef50fe 100644
--- a/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/schedulers/scheduling_euler_ancestral_discrete.py
@@ -1,4 +1,4 @@
-# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,20 +12,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from diffusers.utils import BaseOutput, deprecate, logging
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
+class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    pred_original_sample: Optional[torch.FloatTensor] = None
 
 
 class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """
-    Ancestral sampling with Euler method steps.
-    for discrete beta schedules. Based on the original k-diffusion implementation by
-    Katherine Crowson:
+    Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
     https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
 
     [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
@@ -42,9 +64,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             `linear` or `scaled_linear`.
         trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
-            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
-        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
""" @@ -52,8 +71,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): def __init__( self, num_train_timesteps: int = 1000, - beta_start: float = 0.00085, # sensible defaults - beta_end: float = 0.012, + beta_start: float = 0.0001, + beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, ): @@ -76,20 +95,20 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) - self.init_noise_sigma = None + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() # setable values self.num_inference_steps = None - timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) - self.derivatives = [] self.is_scale_input_called = False def scale_model_input( - self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor] + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: """ - Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. Args: sample (`torch.FloatTensor`): input sample @@ -98,8 +117,12 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): Returns: `torch.FloatTensor`: scaled input sample """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True return sample def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): @@ -109,86 +132,125 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
""" self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) - low_idx = np.floor(self.timesteps).astype(int) - high_idx = np.ceil(self.timesteps).astype(int) - frac = np.mod(self.timesteps, 1.0) + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(self.timesteps) - self.init_noise_sigma = self.sigmas[0] - self.derivatives = [] + self.sigmas = torch.from_numpy(sigmas).to(device=device) + self.timesteps = torch.from_numpy(timesteps).to(device=device) def step( self, - model_output: Union[torch.FloatTensor, np.ndarray], + model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], - step_index: Union[int, torch.IntTensor], - sample: Union[torch.FloatTensor, np.ndarray], - generator: torch.Generator = None, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor` or `np.ndarray`): + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise + a `tuple`. When returning a tuple, the first element is the sample tensor. """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep.", + ) + + if not self.is_scale_input_called: + logger.warn( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." 
+            )
+
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+
+        step_index = (self.timesteps == timestep).nonzero().item()
         sigma = self.sigmas[step_index]
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         pred_original_sample = sample - sigma * model_output
         sigma_from = self.sigmas[step_index]
         sigma_to = self.sigmas[step_index + 1]
-        sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
-        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
+        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
+        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
 
+        # 2. Convert to an ODE derivative
         derivative = (sample - pred_original_sample) / sigma
-        self.derivatives.append(derivative)
 
         dt = sigma_down - sigma
 
         prev_sample = sample + derivative * dt
 
-        prev_sample = prev_sample + torch.randn(
-            prev_sample.shape,
-            layout=prev_sample.layout,
-            device=prev_sample.device,
-            dtype=prev_sample.dtype,
-            generator=generator
-        ) * sigma_up
+        device = model_output.device if torch.is_tensor(model_output) else "cpu"
+        noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
+        prev_sample = prev_sample + noise * sigma_up
 
         if not return_dict:
             return (prev_sample,)
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return EulerAncestralDiscreteSchedulerOutput(
+            prev_sample=prev_sample, pred_original_sample=pred_original_sample
+        )
 
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor,
+        timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
-        self.timesteps = self.timesteps.to(original_samples.device)
-        sigma = self.sigmas[timesteps].flatten()
+        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+            # mps does not support float64
+            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+        else:
+            self.timesteps = self.timesteps.to(original_samples.device)
+            timesteps = timesteps.to(original_samples.device)
+
+        schedule_timesteps = self.timesteps
+
+        if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
+            deprecate(
+                "timesteps as indices",
+                "0.8.0",
+                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                " `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
+                " pass values from `scheduler.timesteps` as timesteps.",
+                standard_warn=False,
+            )
+            step_indices = timesteps
+        else:
+            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = self.sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
-- 
cgit v1.2.3-54-g00ecf
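
A minimal usage sketch of the scheduler API after this change, not taken from the patch itself: it assumes the repository layout above (run from the repository root) and uses a zero tensor as a stand-in for the real UNet prediction; the pipeline hunk above shows the actual call sites. The key point is that `scale_model_input` and `step` now take a timestep value from `scheduler.timesteps` instead of a separate loop index:

    import torch

    from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler

    scheduler = EulerAncestralDiscreteScheduler()
    scheduler.set_timesteps(30, device="cpu")

    # Start from pure noise scaled by the initial noise sigma.
    sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

    for t in scheduler.timesteps:
        # Scale the model input first; step() warns if this was skipped.
        model_input = scheduler.scale_model_input(sample, t)
        # Stand-in for the real UNet call, e.g. unet(model_input, t).sample.
        model_output = torch.zeros_like(model_input)
        # Pass the timestep value itself; integer loop indices now raise a ValueError.
        sample = scheduler.step(model_output, t, sample).prev_sample

Deriving `step_index` from the timestep inside the scheduler keeps the denoising loop scheduler-agnostic, and the old failure mode (denoising with the wrong sigma from a mismatched index) now fails loudly instead of silently.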