diff options
| -rw-r--r-- | dreambooth.py | 6 | ||||
| -rw-r--r-- | infer.py | 16 | ||||
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 27 | ||||
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 286 | ||||
| -rw-r--r-- | schedulers/scheduling_euler_ancestral_discrete.py | 192 | ||||
| -rw-r--r-- | textual_inversion.py | 6 |
6 files changed, 215 insertions, 318 deletions
diff --git a/dreambooth.py b/dreambooth.py index 2c24908..a181293 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -23,7 +23,7 @@ from tqdm.auto import tqdm | |||
| 23 | from transformers import CLIPTextModel, CLIPTokenizer | 23 | from transformers import CLIPTextModel, CLIPTokenizer |
| 24 | from slugify import slugify | 24 | from slugify import slugify |
| 25 | 25 | ||
| 26 | from schedulers.scheduling_euler_a import EulerAScheduler | 26 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler |
| 27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 28 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
| 29 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
| @@ -443,7 +443,7 @@ class Checkpointer: | |||
| 443 | self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) | 443 | self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) |
| 444 | unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 444 | unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
| 445 | 445 | ||
| 446 | scheduler = EulerAScheduler( | 446 | scheduler = EulerAncestralDiscreteScheduler( |
| 447 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 447 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
| 448 | ) | 448 | ) |
| 449 | 449 | ||
| @@ -715,7 +715,7 @@ def main(): | |||
| 715 | for i in range(0, len(missing_data), args.sample_batch_size) | 715 | for i in range(0, len(missing_data), args.sample_batch_size) |
| 716 | ] | 716 | ] |
| 717 | 717 | ||
| 718 | scheduler = EulerAScheduler( | 718 | scheduler = EulerAncestralDiscreteScheduler( |
| 719 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 719 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
| 720 | ) | 720 | ) |
| 721 | 721 | ||
| @@ -12,7 +12,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMSc | |||
| 12 | from transformers import CLIPTextModel, CLIPTokenizer | 12 | from transformers import CLIPTextModel, CLIPTokenizer |
| 13 | from slugify import slugify | 13 | from slugify import slugify |
| 14 | 14 | ||
| 15 | from schedulers.scheduling_euler_a import EulerAScheduler | 15 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler |
| 16 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 16 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 17 | 17 | ||
| 18 | 18 | ||
| @@ -175,16 +175,8 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir): | |||
| 175 | embeddings_dir = Path(embeddings_dir) | 175 | embeddings_dir = Path(embeddings_dir) |
| 176 | embeddings_dir.mkdir(parents=True, exist_ok=True) | 176 | embeddings_dir.mkdir(parents=True, exist_ok=True) |
| 177 | 177 | ||
| 178 | for file in embeddings_dir.iterdir(): | 178 | placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()] |
| 179 | if file.is_file(): | 179 | tokenizer.add_tokens(placeholder_tokens) |
| 180 | placeholder_token = file.stem | ||
| 181 | |||
| 182 | num_added_tokens = tokenizer.add_tokens(placeholder_token) | ||
| 183 | if num_added_tokens == 0: | ||
| 184 | raise ValueError( | ||
| 185 | f"The tokenizer already contains the token {placeholder_token}. Please pass a different" | ||
| 186 | " `placeholder_token` that is not already in the tokenizer." | ||
| 187 | ) | ||
| 188 | 180 | ||
| 189 | text_encoder.resize_token_embeddings(len(tokenizer)) | 181 | text_encoder.resize_token_embeddings(len(tokenizer)) |
| 190 | 182 | ||
| @@ -231,7 +223,7 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): | |||
| 231 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | 223 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False |
| 232 | ) | 224 | ) |
| 233 | else: | 225 | else: |
| 234 | scheduler = EulerAScheduler( | 226 | scheduler = EulerAncestralDiscreteScheduler( |
| 235 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 227 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
| 236 | ) | 228 | ) |
| 237 | 229 | ||
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index e90528d..fc12355 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre | |||
| 11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
| 12 | from diffusers.utils import logging | 12 | from diffusers.utils import logging |
| 13 | from transformers import CLIPTextModel, CLIPTokenizer | 13 | from transformers import CLIPTextModel, CLIPTokenizer |
| 14 | from schedulers.scheduling_euler_a import EulerAScheduler | 14 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler |
| 15 | from models.clip.prompt import PromptProcessor | 15 | from models.clip.prompt import PromptProcessor |
| 16 | 16 | ||
| 17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
| @@ -32,7 +32,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 32 | text_encoder: CLIPTextModel, | 32 | text_encoder: CLIPTextModel, |
| 33 | tokenizer: CLIPTokenizer, | 33 | tokenizer: CLIPTokenizer, |
| 34 | unet: UNet2DConditionModel, | 34 | unet: UNet2DConditionModel, |
| 35 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler], | 35 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler], |
| 36 | **kwargs, | 36 | **kwargs, |
| 37 | ): | 37 | ): |
| 38 | super().__init__() | 38 | super().__init__() |
| @@ -225,8 +225,13 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 225 | init_timestep = int(num_inference_steps * strength) + offset | 225 | init_timestep = int(num_inference_steps * strength) + offset |
| 226 | init_timestep = min(init_timestep, num_inference_steps) | 226 | init_timestep = min(init_timestep, num_inference_steps) |
| 227 | 227 | ||
| 228 | timesteps = self.scheduler.timesteps[-init_timestep] | 228 | if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): |
| 229 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | 229 | timesteps = torch.tensor( |
| 230 | [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device | ||
| 231 | ) | ||
| 232 | else: | ||
| 233 | timesteps = self.scheduler.timesteps[-init_timestep] | ||
| 234 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | ||
| 230 | 235 | ||
| 231 | # add noise to latents using the timesteps | 236 | # add noise to latents using the timesteps |
| 232 | noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) | 237 | noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) |
| @@ -259,16 +264,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 259 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): | 264 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): |
| 260 | # expand the latents if we are doing classifier free guidance | 265 | # expand the latents if we are doing classifier free guidance |
| 261 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 266 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| 262 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 267 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i) |
| 263 | 268 | ||
| 264 | noise_pred = None | 269 | # predict the noise residual |
| 265 | if isinstance(self.scheduler, EulerAScheduler): | 270 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample |
| 266 | c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size) | ||
| 267 | eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample | ||
| 268 | noise_pred = latent_model_input + eps * c_out | ||
| 269 | else: | ||
| 270 | # predict the noise residual | ||
| 271 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | ||
| 272 | 271 | ||
| 273 | # perform guidance | 272 | # perform guidance |
| 274 | if do_classifier_free_guidance: | 273 | if do_classifier_free_guidance: |
| @@ -276,7 +275,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 276 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 275 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
| 277 | 276 | ||
| 278 | # compute the previous noisy sample x_t -> x_t-1 | 277 | # compute the previous noisy sample x_t -> x_t-1 |
| 279 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 278 | latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample |
| 280 | 279 | ||
| 281 | # scale and decode the image latents with vae | 280 | # scale and decode the image latents with vae |
| 282 | latents = 1 / 0.18215 * latents | 281 | latents = 1 / 0.18215 * latents |
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py deleted file mode 100644 index c097a8a..0000000 --- a/schedulers/scheduling_euler_a.py +++ /dev/null | |||
| @@ -1,286 +0,0 @@ | |||
| 1 | from typing import Optional, Tuple, Union | ||
| 2 | |||
| 3 | import numpy as np | ||
| 4 | import torch | ||
| 5 | |||
| 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config | ||
| 7 | from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput | ||
| 8 | |||
| 9 | |||
| 10 | class EulerAScheduler(SchedulerMixin, ConfigMixin): | ||
| 11 | """ | ||
| 12 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and | ||
| 13 | the VE column of Table 1 from [1] for reference. | ||
| 14 | |||
| 15 | [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." | ||
| 16 | https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic | ||
| 17 | differential equations." https://arxiv.org/abs/2011.13456 | ||
| 18 | |||
| 19 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` | ||
| 20 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. | ||
| 21 | [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and | ||
| 22 | [`~ConfigMixin.from_config`] functions. | ||
| 23 | |||
| 24 | For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of | ||
| 25 | Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the | ||
| 26 | optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. | ||
| 27 | |||
| 28 | Args: | ||
| 29 | sigma_min (`float`): minimum noise magnitude | ||
| 30 | sigma_max (`float`): maximum noise magnitude | ||
| 31 | s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. | ||
| 32 | A reasonable range is [1.000, 1.011]. | ||
| 33 | s_churn (`float`): the parameter controlling the overall amount of stochasticity. | ||
| 34 | A reasonable range is [0, 100]. | ||
| 35 | s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). | ||
| 36 | A reasonable range is [0, 10]. | ||
| 37 | s_max (`float`): the end value of the sigma range where we add noise. | ||
| 38 | A reasonable range is [0.2, 80]. | ||
| 39 | |||
| 40 | """ | ||
| 41 | |||
| 42 | @register_to_config | ||
| 43 | def __init__( | ||
| 44 | self, | ||
| 45 | num_train_timesteps: int = 1000, | ||
| 46 | beta_start: float = 0.0001, | ||
| 47 | beta_end: float = 0.02, | ||
| 48 | beta_schedule: str = "linear", | ||
| 49 | trained_betas: Optional[np.ndarray] = None, | ||
| 50 | num_inference_steps=None, | ||
| 51 | device='cuda' | ||
| 52 | ): | ||
| 53 | if trained_betas is not None: | ||
| 54 | self.betas = torch.from_numpy(trained_betas).to(device) | ||
| 55 | if beta_schedule == "linear": | ||
| 56 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device) | ||
| 57 | elif beta_schedule == "scaled_linear": | ||
| 58 | # this schedule is very specific to the latent diffusion model. | ||
| 59 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, | ||
| 60 | dtype=torch.float32, device=device) ** 2 | ||
| 61 | else: | ||
| 62 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | ||
| 63 | |||
| 64 | self.device = device | ||
| 65 | |||
| 66 | self.alphas = 1.0 - self.betas | ||
| 67 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | ||
| 68 | |||
| 69 | # standard deviation of the initial noise distribution | ||
| 70 | self.init_noise_sigma = 1.0 | ||
| 71 | |||
| 72 | # setable values | ||
| 73 | self.num_inference_steps = num_inference_steps | ||
| 74 | self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() | ||
| 75 | # get sigmas | ||
| 76 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | ||
| 77 | self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) | ||
| 78 | |||
| 79 | # A# take number of steps as input | ||
| 80 | # A# store 1) number of steps 2) timesteps 3) schedule | ||
| 81 | |||
| 82 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): | ||
| 83 | """ | ||
| 84 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. | ||
| 85 | |||
| 86 | Args: | ||
| 87 | num_inference_steps (`int`): | ||
| 88 | the number of diffusion steps used when generating samples with a pre-trained model. | ||
| 89 | """ | ||
| 90 | |||
| 91 | self.num_inference_steps = num_inference_steps | ||
| 92 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | ||
| 93 | self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) | ||
| 94 | self.timesteps = self.sigmas[:-1] | ||
| 95 | self.is_scale_input_called = False | ||
| 96 | |||
| 97 | def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor: | ||
| 98 | """ | ||
| 99 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the | ||
| 100 | current timestep. | ||
| 101 | Args: | ||
| 102 | sample (`torch.FloatTensor`): input sample | ||
| 103 | timestep (`int`, optional): current timestep | ||
| 104 | Returns: | ||
| 105 | `torch.FloatTensor`: scaled input sample | ||
| 106 | """ | ||
| 107 | if isinstance(timestep, torch.Tensor): | ||
| 108 | timestep = timestep.to(self.timesteps.device) | ||
| 109 | if self.is_scale_input_called: | ||
| 110 | return sample | ||
| 111 | step_index = (self.timesteps == timestep).nonzero().item() | ||
| 112 | sigma = self.sigmas[step_index] | ||
| 113 | sample = sample * sigma | ||
| 114 | self.is_scale_input_called = True | ||
| 115 | return sample | ||
| 116 | |||
| 117 | def step( | ||
| 118 | self, | ||
| 119 | model_output: torch.FloatTensor, | ||
| 120 | timestep: Union[float, torch.FloatTensor], | ||
| 121 | sample: torch.FloatTensor, | ||
| 122 | generator: torch.Generator = None, | ||
| 123 | return_dict: bool = True, | ||
| 124 | ) -> Union[SchedulerOutput, Tuple]: | ||
| 125 | """ | ||
| 126 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||
| 127 | process from the learned model outputs (most often the predicted noise). | ||
| 128 | |||
| 129 | Args: | ||
| 130 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||
| 131 | sigma_hat (`float`): TODO | ||
| 132 | sigma_prev (`float`): TODO | ||
| 133 | sample_hat (`torch.FloatTensor`): TODO | ||
| 134 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
| 135 | |||
| 136 | EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check). | ||
| 137 | Returns: | ||
| 138 | [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`: | ||
| 139 | [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When | ||
| 140 | returning a tuple, the first element is the sample tensor. | ||
| 141 | |||
| 142 | """ | ||
| 143 | if isinstance(timestep, torch.Tensor): | ||
| 144 | timestep = timestep.to(self.timesteps.device) | ||
| 145 | step_index = (self.timesteps == timestep).nonzero().item() | ||
| 146 | step_prev_index = step_index + 1 | ||
| 147 | |||
| 148 | s = self.sigmas[step_index] | ||
| 149 | s_prev = self.sigmas[step_prev_index] | ||
| 150 | latents = sample | ||
| 151 | |||
| 152 | sigma_down, sigma_up = self.get_ancestral_step(s, s_prev) | ||
| 153 | d = self.to_d(latents, s, model_output) | ||
| 154 | dt = sigma_down - s | ||
| 155 | latents = latents + d * dt | ||
| 156 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, | ||
| 157 | generator=generator) * sigma_up | ||
| 158 | |||
| 159 | return SchedulerOutput(prev_sample=latents) | ||
| 160 | |||
| 161 | def step_correct( | ||
| 162 | self, | ||
| 163 | model_output: torch.FloatTensor, | ||
| 164 | sigma_hat: float, | ||
| 165 | sigma_prev: float, | ||
| 166 | sample_hat: torch.FloatTensor, | ||
| 167 | sample_prev: torch.FloatTensor, | ||
| 168 | derivative: torch.FloatTensor, | ||
| 169 | return_dict: bool = True, | ||
| 170 | ) -> Union[SchedulerOutput, Tuple]: | ||
| 171 | """ | ||
| 172 | Correct the predicted sample based on the output model_output of the network. TODO complete description | ||
| 173 | |||
| 174 | Args: | ||
| 175 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||
| 176 | sigma_hat (`float`): TODO | ||
| 177 | sigma_prev (`float`): TODO | ||
| 178 | sample_hat (`torch.FloatTensor`): TODO | ||
| 179 | sample_prev (`torch.FloatTensor`): TODO | ||
| 180 | derivative (`torch.FloatTensor`): TODO | ||
| 181 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
| 182 | |||
| 183 | Returns: | ||
| 184 | prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO | ||
| 185 | |||
| 186 | """ | ||
| 187 | pred_original_sample = sample_prev + sigma_prev * model_output | ||
| 188 | derivative_corr = (sample_prev - pred_original_sample) / sigma_prev | ||
| 189 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) | ||
| 190 | |||
| 191 | if not return_dict: | ||
| 192 | return (sample_prev, derivative) | ||
| 193 | |||
| 194 | return SchedulerOutput(prev_sample=sample_prev) | ||
| 195 | |||
| 196 | def add_noise( | ||
| 197 | self, | ||
| 198 | original_samples: torch.FloatTensor, | ||
| 199 | noise: torch.FloatTensor, | ||
| 200 | timesteps: torch.FloatTensor, | ||
| 201 | ) -> torch.FloatTensor: | ||
| 202 | sigmas = self.sigmas.to(original_samples.device) | ||
| 203 | schedule_timesteps = self.timesteps.to(original_samples.device) | ||
| 204 | timesteps = timesteps.to(original_samples.device) | ||
| 205 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] | ||
| 206 | |||
| 207 | sigma = sigmas[step_indices].flatten() | ||
| 208 | while len(sigma.shape) < len(original_samples.shape): | ||
| 209 | sigma = sigma.unsqueeze(-1) | ||
| 210 | |||
| 211 | noisy_samples = original_samples + noise * sigma | ||
| 212 | self.is_scale_input_called = True | ||
| 213 | return noisy_samples | ||
| 214 | |||
| 215 | # from k_samplers sampling.py | ||
| 216 | |||
| 217 | def get_ancestral_step(self, sigma_from, sigma_to): | ||
| 218 | """Calculates the noise level (sigma_down) to step down to and the amount | ||
| 219 | of noise to add (sigma_up) when doing an ancestral sampling step.""" | ||
| 220 | sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 | ||
| 221 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 | ||
| 222 | return sigma_down, sigma_up | ||
| 223 | |||
| 224 | def t_to_sigma(self, t, sigmas): | ||
| 225 | t = t.float() | ||
| 226 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() | ||
| 227 | return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] | ||
| 228 | |||
| 229 | def append_zero(self, x): | ||
| 230 | return torch.cat([x, x.new_zeros([1])]) | ||
| 231 | |||
| 232 | def get_sigmas(self, sigmas, n=None): | ||
| 233 | if n is None: | ||
| 234 | return self.append_zero(sigmas.flip(0)) | ||
| 235 | t_max = len(sigmas) - 1 # = 999 | ||
| 236 | device = self.device | ||
| 237 | t = torch.linspace(t_max, 0, n, device=device) | ||
| 238 | # t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
| 239 | return self.append_zero(self.t_to_sigma(t, sigmas)) | ||
| 240 | |||
| 241 | # from k_samplers utils.py | ||
| 242 | def append_dims(self, x, target_dims): | ||
| 243 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" | ||
| 244 | dims_to_append = target_dims - x.ndim | ||
| 245 | if dims_to_append < 0: | ||
| 246 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') | ||
| 247 | return x[(...,) + (None,) * dims_to_append] | ||
| 248 | |||
| 249 | # from k_samplers sampling.py | ||
| 250 | def to_d(self, x, sigma, denoised): | ||
| 251 | """Converts a denoiser output to a Karras ODE derivative.""" | ||
| 252 | return (x - denoised) / self.append_dims(sigma, x.ndim) | ||
| 253 | |||
| 254 | def get_scalings(self, sigma): | ||
| 255 | sigma_data = 1. | ||
| 256 | c_out = -sigma | ||
| 257 | c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 | ||
| 258 | return c_out, c_in | ||
| 259 | |||
| 260 | # DiscreteSchedule DS | ||
| 261 | def DSsigma_to_t(self, sigma, quantize=None): | ||
| 262 | # quantize = self.quantize if quantize is None else quantize | ||
| 263 | quantize = False | ||
| 264 | dists = torch.abs(sigma - self.DSsigmas[:, None]) | ||
| 265 | if quantize: | ||
| 266 | return torch.argmin(dists, dim=0).view(sigma.shape) | ||
| 267 | low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] | ||
| 268 | low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx] | ||
| 269 | w = (low - sigma) / (low - high) | ||
| 270 | w = w.clamp(0, 1) | ||
| 271 | t = (1 - w) * low_idx + w * high_idx | ||
| 272 | return t.view(sigma.shape) | ||
| 273 | |||
| 274 | def prepare_input(self, latent_in, t, batch_size): | ||
| 275 | sigma = t.reshape(1) # A# potential bug: doesn't work on samples > 1 | ||
| 276 | |||
| 277 | sigma_in = torch.cat([sigma] * 2 * batch_size) | ||
| 278 | # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas) | ||
| 279 | # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in) | ||
| 280 | c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)] | ||
| 281 | |||
| 282 | sigma_in = self.DSsigma_to_t(sigma_in) | ||
| 283 | # s_in = latent_in.new_ones([latent_in.shape[0]]) | ||
| 284 | # sigma_in = sigma_in * s_in | ||
| 285 | |||
| 286 | return c_out, c_in, sigma_in | ||
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100644 index 0000000..3a2de68 --- /dev/null +++ b/schedulers/scheduling_euler_ancestral_discrete.py | |||
| @@ -0,0 +1,192 @@ | |||
| 1 | # Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. | ||
| 2 | # | ||
| 3 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 4 | # you may not use this file except in compliance with the License. | ||
| 5 | # You may obtain a copy of the License at | ||
| 6 | # | ||
| 7 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
| 8 | # | ||
| 9 | # Unless required by applicable law or agreed to in writing, software | ||
| 10 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
| 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 12 | # See the License for the specific language governing permissions and | ||
| 13 | # limitations under the License. | ||
| 14 | |||
| 15 | from typing import Optional, Tuple, Union | ||
| 16 | |||
| 17 | import numpy as np | ||
| 18 | import torch | ||
| 19 | |||
| 20 | from diffusers.configuration_utils import ConfigMixin, register_to_config | ||
| 21 | from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput | ||
| 22 | |||
| 23 | |||
| 24 | class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): | ||
| 25 | """ | ||
| 26 | Ancestral sampling with Euler method steps. | ||
| 27 | for discrete beta schedules. Based on the original k-diffusion implementation by | ||
| 28 | Katherine Crowson: | ||
| 29 | https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 | ||
| 30 | |||
| 31 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` | ||
| 32 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. | ||
| 33 | [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and | ||
| 34 | [`~ConfigMixin.from_config`] functions. | ||
| 35 | |||
| 36 | Args: | ||
| 37 | num_train_timesteps (`int`): number of diffusion steps used to train the model. | ||
| 38 | beta_start (`float`): the starting `beta` value of inference. | ||
| 39 | beta_end (`float`): the final `beta` value. | ||
| 40 | beta_schedule (`str`): | ||
| 41 | the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from | ||
| 42 | `linear` or `scaled_linear`. | ||
| 43 | trained_betas (`np.ndarray`, optional): | ||
| 44 | option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. | ||
| 45 | options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, | ||
| 46 | `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. | ||
| 47 | tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. | ||
| 48 | |||
| 49 | """ | ||
| 50 | |||
| 51 | @register_to_config | ||
| 52 | def __init__( | ||
| 53 | self, | ||
| 54 | num_train_timesteps: int = 1000, | ||
| 55 | beta_start: float = 0.00085, # sensible defaults | ||
| 56 | beta_end: float = 0.012, | ||
| 57 | beta_schedule: str = "linear", | ||
| 58 | trained_betas: Optional[np.ndarray] = None, | ||
| 59 | ): | ||
| 60 | if trained_betas is not None: | ||
| 61 | self.betas = torch.from_numpy(trained_betas) | ||
| 62 | elif beta_schedule == "linear": | ||
| 63 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) | ||
| 64 | elif beta_schedule == "scaled_linear": | ||
| 65 | # this schedule is very specific to the latent diffusion model. | ||
| 66 | self.betas = ( | ||
| 67 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 | ||
| 68 | ) | ||
| 69 | else: | ||
| 70 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | ||
| 71 | |||
| 72 | self.alphas = 1.0 - self.betas | ||
| 73 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | ||
| 74 | |||
| 75 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) | ||
| 76 | sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) | ||
| 77 | self.sigmas = torch.from_numpy(sigmas) | ||
| 78 | |||
| 79 | self.init_noise_sigma = None | ||
| 80 | |||
| 81 | # setable values | ||
| 82 | self.num_inference_steps = None | ||
| 83 | timesteps = np.arange(0, num_train_timesteps)[::-1].copy() | ||
| 84 | self.timesteps = torch.from_numpy(timesteps) | ||
| 85 | self.derivatives = [] | ||
| 86 | self.is_scale_input_called = False | ||
| 87 | |||
| 88 | def scale_model_input( | ||
| 89 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor] | ||
| 90 | ) -> torch.FloatTensor: | ||
| 91 | """ | ||
| 92 | Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. | ||
| 93 | |||
| 94 | Args: | ||
| 95 | sample (`torch.FloatTensor`): input sample | ||
| 96 | timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain | ||
| 97 | |||
| 98 | Returns: | ||
| 99 | `torch.FloatTensor`: scaled input sample | ||
| 100 | """ | ||
| 101 | sigma = self.sigmas[step_index] | ||
| 102 | sample = sample / ((sigma**2 + 1) ** 0.5) | ||
| 103 | return sample | ||
| 104 | |||
| 105 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): | ||
| 106 | """ | ||
| 107 | Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. | ||
| 108 | |||
| 109 | Args: | ||
| 110 | num_inference_steps (`int`): | ||
| 111 | the number of diffusion steps used when generating samples with a pre-trained model. | ||
| 112 | """ | ||
| 113 | self.num_inference_steps = num_inference_steps | ||
| 114 | self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) | ||
| 115 | |||
| 116 | low_idx = np.floor(self.timesteps).astype(int) | ||
| 117 | high_idx = np.ceil(self.timesteps).astype(int) | ||
| 118 | frac = np.mod(self.timesteps, 1.0) | ||
| 119 | sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) | ||
| 120 | sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] | ||
| 121 | sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) | ||
| 122 | self.sigmas = torch.from_numpy(sigmas) | ||
| 123 | self.timesteps = torch.from_numpy(self.timesteps) | ||
| 124 | self.init_noise_sigma = self.sigmas[0] | ||
| 125 | self.derivatives = [] | ||
| 126 | |||
| 127 | def step( | ||
| 128 | self, | ||
| 129 | model_output: Union[torch.FloatTensor, np.ndarray], | ||
| 130 | timestep: Union[float, torch.FloatTensor], | ||
| 131 | step_index: Union[int, torch.IntTensor], | ||
| 132 | sample: Union[torch.FloatTensor, np.ndarray], | ||
| 133 | return_dict: bool = True, | ||
| 134 | ) -> Union[SchedulerOutput, Tuple]: | ||
| 135 | """ | ||
| 136 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||
| 137 | process from the learned model outputs (most often the predicted noise). | ||
| 138 | |||
| 139 | Args: | ||
| 140 | model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. | ||
| 141 | timestep (`int`): current discrete timestep in the diffusion chain. | ||
| 142 | sample (`torch.FloatTensor` or `np.ndarray`): | ||
| 143 | current instance of sample being created by diffusion process. | ||
| 144 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
| 145 | |||
| 146 | Returns: | ||
| 147 | [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: | ||
| 148 | [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When | ||
| 149 | returning a tuple, the first element is the sample tensor. | ||
| 150 | |||
| 151 | """ | ||
| 152 | sigma = self.sigmas[step_index] | ||
| 153 | |||
| 154 | # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise | ||
| 155 | pred_original_sample = sample - sigma * model_output | ||
| 156 | sigma_from = self.sigmas[step_index] | ||
| 157 | sigma_to = self.sigmas[step_index + 1] | ||
| 158 | sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 | ||
| 159 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 | ||
| 160 | # 2. Convert to an ODE derivative | ||
| 161 | derivative = (sample - pred_original_sample) / sigma | ||
| 162 | self.derivatives.append(derivative) | ||
| 163 | |||
| 164 | dt = sigma_down - sigma | ||
| 165 | |||
| 166 | prev_sample = sample + derivative * dt | ||
| 167 | |||
| 168 | prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up | ||
| 169 | |||
| 170 | if not return_dict: | ||
| 171 | return (prev_sample,) | ||
| 172 | |||
| 173 | return SchedulerOutput(prev_sample=prev_sample) | ||
| 174 | |||
| 175 | def add_noise( | ||
| 176 | self, | ||
| 177 | original_samples: torch.FloatTensor, | ||
| 178 | noise: torch.FloatTensor, | ||
| 179 | timesteps: torch.IntTensor, | ||
| 180 | ) -> torch.FloatTensor: | ||
| 181 | # Make sure sigmas and timesteps have the same device and dtype as original_samples | ||
| 182 | self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) | ||
| 183 | self.timesteps = self.timesteps.to(original_samples.device) | ||
| 184 | sigma = self.sigmas[timesteps].flatten() | ||
| 185 | while len(sigma.shape) < len(original_samples.shape): | ||
| 186 | sigma = sigma.unsqueeze(-1) | ||
| 187 | |||
| 188 | noisy_samples = original_samples + noise * sigma | ||
| 189 | return noisy_samples | ||
| 190 | |||
| 191 | def __len__(self): | ||
| 192 | return self.config.num_train_timesteps | ||
diff --git a/textual_inversion.py b/textual_inversion.py index bcdfd3a..dd7c3bd 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -22,7 +22,7 @@ from tqdm.auto import tqdm | |||
| 22 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
| 23 | from slugify import slugify | 23 | from slugify import slugify |
| 24 | 24 | ||
| 25 | from schedulers.scheduling_euler_a import EulerAScheduler | 25 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler |
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
| 28 | from models.clip.prompt import PromptProcessor | 28 | from models.clip.prompt import PromptProcessor |
| @@ -398,7 +398,7 @@ class Checkpointer: | |||
| 398 | samples_path = Path(self.output_dir).joinpath("samples") | 398 | samples_path = Path(self.output_dir).joinpath("samples") |
| 399 | 399 | ||
| 400 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | 400 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) |
| 401 | scheduler = EulerAScheduler( | 401 | scheduler = EulerAncestralDiscreteScheduler( |
| 402 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 402 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
| 403 | ) | 403 | ) |
| 404 | 404 | ||
| @@ -639,7 +639,7 @@ def main(): | |||
| 639 | batched_data = [missing_data[i:i+args.sample_batch_size] | 639 | batched_data = [missing_data[i:i+args.sample_batch_size] |
| 640 | for i in range(0, len(missing_data), args.sample_batch_size)] | 640 | for i in range(0, len(missing_data), args.sample_batch_size)] |
| 641 | 641 | ||
| 642 | scheduler = EulerAScheduler( | 642 | scheduler = EulerAncestralDiscreteScheduler( |
| 643 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 643 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
| 644 | ) | 644 | ) |
| 645 | 645 | ||
