From 49463992f48ec25f2ea31b220a6cedac3466467a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 26 Oct 2022 11:11:33 +0200 Subject: New Euler_a scheduler --- dreambooth.py | 6 +- infer.py | 16 +- .../stable_diffusion/vlpn_stable_diffusion.py | 27 +- schedulers/scheduling_euler_a.py | 286 --------------------- schedulers/scheduling_euler_ancestral_discrete.py | 192 ++++++++++++++ textual_inversion.py | 6 +- 6 files changed, 215 insertions(+), 318 deletions(-) delete mode 100644 schedulers/scheduling_euler_a.py create mode 100644 schedulers/scheduling_euler_ancestral_discrete.py diff --git a/dreambooth.py b/dreambooth.py index 2c24908..a181293 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -23,7 +23,7 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from schedulers.scheduling_euler_a import EulerAScheduler +from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule from models.clip.prompt import PromptProcessor @@ -443,7 +443,7 @@ class Checkpointer: self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) - scheduler = EulerAScheduler( + scheduler = EulerAncestralDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) @@ -715,7 +715,7 @@ def main(): for i in range(0, len(missing_data), args.sample_batch_size) ] - scheduler = EulerAScheduler( + scheduler = EulerAncestralDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) diff --git a/infer.py b/infer.py index 01010eb..ac05955 100644 --- a/infer.py +++ b/infer.py @@ -12,7 +12,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMSc from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from schedulers.scheduling_euler_a import EulerAScheduler +from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion @@ -175,16 +175,8 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir): embeddings_dir = Path(embeddings_dir) embeddings_dir.mkdir(parents=True, exist_ok=True) - for file in embeddings_dir.iterdir(): - if file.is_file(): - placeholder_token = file.stem - - num_added_tokens = tokenizer.add_tokens(placeholder_token) - if num_added_tokens == 0: - raise ValueError( - f"The tokenizer already contains the token {placeholder_token}. Please pass a different" - " `placeholder_token` that is not already in the tokenizer." - ) + placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()] + tokenizer.add_tokens(placeholder_tokens) text_encoder.resize_token_embeddings(len(tokenizer)) @@ -231,7 +223,7 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False ) else: - scheduler = EulerAScheduler( + scheduler = EulerAncestralDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index e90528d..fc12355 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging from transformers import CLIPTextModel, CLIPTokenizer -from schedulers.scheduling_euler_a import EulerAScheduler +from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from models.clip.prompt import PromptProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -32,7 +32,7 @@ class VlpnStableDiffusion(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler], + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler], **kwargs, ): super().__init__() @@ -225,8 +225,13 @@ class VlpnStableDiffusion(DiffusionPipeline): init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, device=self.device) + if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): + timesteps = torch.tensor( + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, device=self.device) # add noise to latents using the timesteps noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) @@ -259,16 +264,10 @@ class VlpnStableDiffusion(DiffusionPipeline): for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i) - noise_pred = None - if isinstance(self.scheduler, EulerAScheduler): - c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size) - eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample - noise_pred = latent_model_input + eps * c_out - else: - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: @@ -276,7 +275,7 @@ class VlpnStableDiffusion(DiffusionPipeline): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py deleted file mode 100644 index c097a8a..0000000 --- a/schedulers/scheduling_euler_a.py +++ /dev/null @@ -1,286 +0,0 @@ -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput - - -class EulerAScheduler(SchedulerMixin, ConfigMixin): - """ - Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and - the VE column of Table 1 from [1] for reference. - - [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." - https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic - differential equations." https://arxiv.org/abs/2011.13456 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. - - For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of - Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the - optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. - - Args: - sigma_min (`float`): minimum noise magnitude - sigma_max (`float`): maximum noise magnitude - s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. - A reasonable range is [1.000, 1.011]. - s_churn (`float`): the parameter controlling the overall amount of stochasticity. - A reasonable range is [0, 100]. - s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). - A reasonable range is [0, 10]. - s_max (`float`): the end value of the sigma range where we add noise. - A reasonable range is [0.2, 80]. - - """ - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - num_inference_steps=None, - device='cuda' - ): - if trained_betas is not None: - self.betas = torch.from_numpy(trained_betas).to(device) - if beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, - dtype=torch.float32, device=device) ** 2 - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.device = device - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # setable values - self.num_inference_steps = num_inference_steps - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - # get sigmas - self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) - - # A# take number of steps as input - # A# store 1) number of steps 2) timesteps 3) schedule - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - - self.num_inference_steps = num_inference_steps - self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) - self.timesteps = self.sigmas[:-1] - self.is_scale_input_called = False - - def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - Returns: - `torch.FloatTensor`: scaled input sample - """ - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - if self.is_scale_input_called: - return sample - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - sample = sample * sigma - self.is_scale_input_called = True - return sample - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - generator: torch.Generator = None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check). - Returns: - [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`: - [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero().item() - step_prev_index = step_index + 1 - - s = self.sigmas[step_index] - s_prev = self.sigmas[step_prev_index] - latents = sample - - sigma_down, sigma_up = self.get_ancestral_step(s, s_prev) - d = self.to_d(latents, s, model_output) - dt = sigma_down - s - latents = latents + d * dt - latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, - generator=generator) * sigma_up - - return SchedulerOutput(prev_sample=latents) - - def step_correct( - self, - model_output: torch.FloatTensor, - sigma_hat: float, - sigma_prev: float, - sample_hat: torch.FloatTensor, - sample_prev: torch.FloatTensor, - derivative: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Correct the predicted sample based on the output model_output of the network. TODO complete description - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor`): TODO - sample_prev (`torch.FloatTensor`): TODO - derivative (`torch.FloatTensor`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO - - """ - pred_original_sample = sample_prev + sigma_prev * model_output - derivative_corr = (sample_prev - pred_original_sample) / sigma_prev - sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) - - if not return_dict: - return (sample_prev, derivative) - - return SchedulerOutput(prev_sample=sample_prev) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - sigmas = self.sigmas.to(original_samples.device) - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - self.is_scale_input_called = True - return noisy_samples - - # from k_samplers sampling.py - - def get_ancestral_step(self, sigma_from, sigma_to): - """Calculates the noise level (sigma_down) to step down to and the amount - of noise to add (sigma_up) when doing an ancestral sampling step.""" - sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 - sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 - return sigma_down, sigma_up - - def t_to_sigma(self, t, sigmas): - t = t.float() - low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() - return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] - - def append_zero(self, x): - return torch.cat([x, x.new_zeros([1])]) - - def get_sigmas(self, sigmas, n=None): - if n is None: - return self.append_zero(sigmas.flip(0)) - t_max = len(sigmas) - 1 # = 999 - device = self.device - t = torch.linspace(t_max, 0, n, device=device) - # t = torch.linspace(t_max, 0, n, device=sigmas.device) - return self.append_zero(self.t_to_sigma(t, sigmas)) - - # from k_samplers utils.py - def append_dims(self, x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] - - # from k_samplers sampling.py - def to_d(self, x, sigma, denoised): - """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / self.append_dims(sigma, x.ndim) - - def get_scalings(self, sigma): - sigma_data = 1. - c_out = -sigma - c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 - return c_out, c_in - - # DiscreteSchedule DS - def DSsigma_to_t(self, sigma, quantize=None): - # quantize = self.quantize if quantize is None else quantize - quantize = False - dists = torch.abs(sigma - self.DSsigmas[:, None]) - if quantize: - return torch.argmin(dists, dim=0).view(sigma.shape) - low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] - low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx] - w = (low - sigma) / (low - high) - w = w.clamp(0, 1) - t = (1 - w) * low_idx + w * high_idx - return t.view(sigma.shape) - - def prepare_input(self, latent_in, t, batch_size): - sigma = t.reshape(1) # A# potential bug: doesn't work on samples > 1 - - sigma_in = torch.cat([sigma] * 2 * batch_size) - # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas) - # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in) - c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)] - - sigma_in = self.DSsigma_to_t(sigma_in) - # s_in = latent_in.new_ones([latent_in.shape[0]]) - # sigma_in = sigma_in * s_in - - return c_out, c_in, sigma_in diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100644 index 0000000..3a2de68 --- /dev/null +++ b/schedulers/scheduling_euler_ancestral_discrete.py @@ -0,0 +1,192 @@ +# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput + + +class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Ancestral sampling with Euler method steps. + for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + self.init_noise_sigma = None + + # setable values + self.num_inference_steps = None + timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.derivatives = [] + self.is_scale_input_called = False + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + + low_idx = np.floor(self.timesteps).astype(int) + high_idx = np.ceil(self.timesteps).astype(int) + frac = np.mod(self.timesteps, 1.0) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(self.timesteps) + self.init_noise_sigma = self.sigmas[0] + self.derivatives = [] + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + step_index: Union[int, torch.IntTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + self.derivatives.append(derivative) + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + self.timesteps = self.timesteps.to(original_samples.device) + sigma = self.sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/textual_inversion.py b/textual_inversion.py index bcdfd3a..dd7c3bd 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -22,7 +22,7 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from schedulers.scheduling_euler_a import EulerAScheduler +from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule from models.clip.prompt import PromptProcessor @@ -398,7 +398,7 @@ class Checkpointer: samples_path = Path(self.output_dir).joinpath("samples") unwrapped = self.accelerator.unwrap_model(self.text_encoder) - scheduler = EulerAScheduler( + scheduler = EulerAncestralDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) @@ -639,7 +639,7 @@ def main(): batched_data = [missing_data[i:i+args.sample_batch_size] for i in range(0, len(missing_data), args.sample_batch_size)] - scheduler = EulerAScheduler( + scheduler = EulerAncestralDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) -- cgit v1.2.3-54-g00ecf