author    Volpeon <git@volpeon.ink>  2022-10-26 11:11:33 +0200
committer Volpeon <git@volpeon.ink>  2022-10-26 11:11:33 +0200
commit    49463992f48ec25f2ea31b220a6cedac3466467a (patch)
tree      a58f40e558c14403dbeda687708ef334371694b8
parent    Advanced datasets (diff)
download  textual-inversion-diff-49463992f48ec25f2ea31b220a6cedac3466467a.tar.gz
          textual-inversion-diff-49463992f48ec25f2ea31b220a6cedac3466467a.tar.bz2
          textual-inversion-diff-49463992f48ec25f2ea31b220a6cedac3466467a.zip
New Euler_a scheduler
-rw-r--r--  dreambooth.py                                        |   6
-rw-r--r--  infer.py                                             |  16
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  27
-rw-r--r--  schedulers/scheduling_euler_a.py                     | 286
-rw-r--r--  schedulers/scheduling_euler_ancestral_discrete.py    | 192
-rw-r--r--  textual_inversion.py                                 |   6
6 files changed, 215 insertions(+), 318 deletions(-)
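
In short: the hand-rolled EulerAScheduler (and the prepare_input/c_in/c_out wrapping it required around every UNet call) is replaced with an EulerAncestralDiscreteScheduler ported from Katherine Crowson's k-diffusion, and all call sites are switched over. The new class keeps the familiar diffusers interface, except that scale_model_input and step now take the loop index alongside the timestep, since inference timesteps are fractional after this change. A minimal sketch of the resulting denoising loop (unet, latents, and text_embeddings are stand-ins; scaling the initial noise by init_noise_sigma is an assumption based on how set_timesteps defines it):

    scheduler = EulerAncestralDiscreteScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )
    scheduler.set_timesteps(50)
    latents = latents * scheduler.init_noise_sigma  # assumed: noise starts at unit variance
    for i, t in enumerate(scheduler.timesteps):
        model_input = scheduler.scale_model_input(latents, t, i)  # divides by (sigma**2 + 1) ** 0.5
        noise_pred = unet(model_input, t, encoder_hidden_states=text_embeddings).sample
        latents = scheduler.step(noise_pred, t, i, latents).prev_sample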
diff --git a/dreambooth.py b/dreambooth.py
index 2c24908..a181293 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -23,7 +23,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify

-from schedulers.scheduling_euler_a import EulerAScheduler
+from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from models.clip.prompt import PromptProcessor
@@ -443,7 +443,7 @@ class Checkpointer:
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)

-        scheduler = EulerAScheduler(
+        scheduler = EulerAncestralDiscreteScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )

@@ -715,7 +715,7 @@ def main():
         for i in range(0, len(missing_data), args.sample_batch_size)
     ]

-    scheduler = EulerAScheduler(
+    scheduler = EulerAncestralDiscreteScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
     )

diff --git a/infer.py b/infer.py
index 01010eb..ac05955 100644
--- a/infer.py
+++ b/infer.py
@@ -12,7 +12,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMSc
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify

-from schedulers.scheduling_euler_a import EulerAScheduler
+from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion


@@ -175,16 +175,8 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
     embeddings_dir = Path(embeddings_dir)
     embeddings_dir.mkdir(parents=True, exist_ok=True)

-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-
-            num_added_tokens = tokenizer.add_tokens(placeholder_token)
-            if num_added_tokens == 0:
-                raise ValueError(
-                    f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
-                    " `placeholder_token` that is not already in the tokenizer."
-                )
+    placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()]
+    tokenizer.add_tokens(placeholder_tokens)

     text_encoder.resize_token_embeddings(len(tokenizer))

@@ -231,7 +223,7 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
         )
     else:
-        scheduler = EulerAScheduler(
+        scheduler = EulerAncestralDiscreteScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )

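
A note on the load_embeddings_ti rewrite above: tokenizer.add_tokens accepts a whole list and returns the number of tokens that were actually new, silently skipping any that already exist, so the old per-token ValueError becomes redundant rather than merely inlined. A small sketch of the same pattern, with a hypothetical embeddings directory:

    from pathlib import Path
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    embeddings_dir = Path("embeddings")  # hypothetical: one <placeholder_token>.bin per file
    placeholder_tokens = [f.stem for f in embeddings_dir.iterdir() if f.is_file()]
    num_added = tokenizer.add_tokens(placeholder_tokens)  # duplicates are skipped; returns count of new tokens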
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index e90528d..fc12355 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
-from schedulers.scheduling_euler_a import EulerAScheduler
+from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from models.clip.prompt import PromptProcessor

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -32,7 +32,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler],
         **kwargs,
     ):
         super().__init__()
@@ -225,8 +225,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)

-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler):
+            timesteps = torch.tensor(
+                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+            )
+        else:
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
@@ -259,16 +264,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i)

-            noise_pred = None
-            if isinstance(self.scheduler, EulerAScheduler):
-                c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size)
-                eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample
-                noise_pred = latent_model_input + eps * c_out
-            else:
-                # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

             # perform guidance
             if do_classifier_free_guidance:
@@ -276,7 +275,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
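
Two changes land in this file: the EulerAScheduler special case in the denoising loop is gone (input scaling now lives in scale_model_input, so the UNet call is uniform across schedulers), and the img2img start timestep is computed per scheduler family. A sketch of that branch in isolation, on the assumption that index-based schedulers like the new one expect a position in the inference schedule while DDIM/PNDM expect a value from the training-timestep schedule:

    import torch
    from diffusers import DDIMScheduler, PNDMScheduler

    def img2img_start_timesteps(scheduler, num_inference_steps, strength, offset, batch_size):
        # hypothetical helper mirroring the pipeline logic above
        init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
        if not isinstance(scheduler, (DDIMScheduler, PNDMScheduler)):
            # position in the inference schedule (0 = most noise)
            return torch.tensor([num_inference_steps - init_timestep] * batch_size, dtype=torch.long)
        # actual timestep value taken from the schedule
        return torch.tensor([scheduler.timesteps[-init_timestep]] * batch_size)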
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
deleted file mode 100644
index c097a8a..0000000
--- a/schedulers/scheduling_euler_a.py
+++ /dev/null
@@ -1,286 +0,0 @@
-from typing import Optional, Tuple, Union
-
-import numpy as np
-import torch
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
-
-
-class EulerAScheduler(SchedulerMixin, ConfigMixin):
-    """
-    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
-    the VE column of Table 1 from [1] for reference.
-
-    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
-    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
-    differential equations." https://arxiv.org/abs/2011.13456
-
-    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
-    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
-
-    For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
-    Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
-    optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
-
-    Args:
-        sigma_min (`float`): minimum noise magnitude
-        sigma_max (`float`): maximum noise magnitude
-        s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
-            A reasonable range is [1.000, 1.011].
-        s_churn (`float`): the parameter controlling the overall amount of stochasticity.
-            A reasonable range is [0, 100].
-        s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
-            A reasonable range is [0, 10].
-        s_max (`float`): the end value of the sigma range where we add noise.
-            A reasonable range is [0.2, 80].
-
-    """
-
-    @register_to_config
-    def __init__(
-        self,
-        num_train_timesteps: int = 1000,
-        beta_start: float = 0.0001,
-        beta_end: float = 0.02,
-        beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
-        num_inference_steps=None,
-        device='cuda'
-    ):
-        if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas).to(device)
-        if beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps,
-                                        dtype=torch.float32, device=device) ** 2
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-
-        self.device = device
-
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-
-        # standard deviation of the initial noise distribution
-        self.init_noise_sigma = 1.0
-
-        # setable values
-        self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
-        # get sigmas
-        self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
-
-    # A# take number of steps as input
-    # A# store 1) number of steps 2) timesteps 3) schedule
-
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
-        """
-        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
-
-        Args:
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
-        """
-
-        self.num_inference_steps = num_inference_steps
-        self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
-        self.timesteps = self.sigmas[:-1]
-        self.is_scale_input_called = False
-
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor:
-        """
-        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
-        current timestep.
-        Args:
-            sample (`torch.FloatTensor`): input sample
-            timestep (`int`, optional): current timestep
-        Returns:
-            `torch.FloatTensor`: scaled input sample
-        """
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-        if self.is_scale_input_called:
-            return sample
-        step_index = (self.timesteps == timestep).nonzero().item()
-        sigma = self.sigmas[step_index]
-        sample = sample * sigma
-        self.is_scale_input_called = True
-        return sample
-
-    def step(
-        self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
-        generator: torch.Generator = None,
-        return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
-        """
-        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
-        process from the learned model outputs (most often the predicted noise).
-
-        Args:
-            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
-            sigma_hat (`float`): TODO
-            sigma_prev (`float`): TODO
-            sample_hat (`torch.FloatTensor`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
-
-        EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check).
-        Returns:
-            [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`:
-            [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When
-            returning a tuple, the first element is the sample tensor.
-
-        """
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-        step_index = (self.timesteps == timestep).nonzero().item()
-        step_prev_index = step_index + 1
-
-        s = self.sigmas[step_index]
-        s_prev = self.sigmas[step_prev_index]
-        latents = sample
-
-        sigma_down, sigma_up = self.get_ancestral_step(s, s_prev)
-        d = self.to_d(latents, s, model_output)
-        dt = sigma_down - s
-        latents = latents + d * dt
-        latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype,
-                                        generator=generator) * sigma_up
-
-        return SchedulerOutput(prev_sample=latents)
-
-    def step_correct(
-        self,
-        model_output: torch.FloatTensor,
-        sigma_hat: float,
-        sigma_prev: float,
-        sample_hat: torch.FloatTensor,
-        sample_prev: torch.FloatTensor,
-        derivative: torch.FloatTensor,
-        return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
-        """
-        Correct the predicted sample based on the output model_output of the network. TODO complete description
-
-        Args:
-            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
-            sigma_hat (`float`): TODO
-            sigma_prev (`float`): TODO
-            sample_hat (`torch.FloatTensor`): TODO
-            sample_prev (`torch.FloatTensor`): TODO
-            derivative (`torch.FloatTensor`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
-
-        Returns:
-            prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
-
-        """
-        pred_original_sample = sample_prev + sigma_prev * model_output
-        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
-        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
-
-        if not return_dict:
-            return (sample_prev, derivative)
-
-        return SchedulerOutput(prev_sample=sample_prev)
-
-    def add_noise(
-        self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
-        sigmas = self.sigmas.to(original_samples.device)
-        schedule_timesteps = self.timesteps.to(original_samples.device)
-        timesteps = timesteps.to(original_samples.device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < len(original_samples.shape):
-            sigma = sigma.unsqueeze(-1)
-
-        noisy_samples = original_samples + noise * sigma
-        self.is_scale_input_called = True
-        return noisy_samples
-
-    # from k_samplers sampling.py
-
-    def get_ancestral_step(self, sigma_from, sigma_to):
-        """Calculates the noise level (sigma_down) to step down to and the amount
-        of noise to add (sigma_up) when doing an ancestral sampling step."""
-        sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
-        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
-        return sigma_down, sigma_up
-
-    def t_to_sigma(self, t, sigmas):
-        t = t.float()
-        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
-        return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]
-
-    def append_zero(self, x):
-        return torch.cat([x, x.new_zeros([1])])
-
-    def get_sigmas(self, sigmas, n=None):
-        if n is None:
-            return self.append_zero(sigmas.flip(0))
-        t_max = len(sigmas) - 1  # = 999
-        device = self.device
-        t = torch.linspace(t_max, 0, n, device=device)
-        # t = torch.linspace(t_max, 0, n, device=sigmas.device)
-        return self.append_zero(self.t_to_sigma(t, sigmas))
-
-    # from k_samplers utils.py
-    def append_dims(self, x, target_dims):
-        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
-        dims_to_append = target_dims - x.ndim
-        if dims_to_append < 0:
-            raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
-        return x[(...,) + (None,) * dims_to_append]
-
-    # from k_samplers sampling.py
-    def to_d(self, x, sigma, denoised):
-        """Converts a denoiser output to a Karras ODE derivative."""
-        return (x - denoised) / self.append_dims(sigma, x.ndim)
-
-    def get_scalings(self, sigma):
-        sigma_data = 1.
-        c_out = -sigma
-        c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
-        return c_out, c_in
-
-    # DiscreteSchedule DS
-    def DSsigma_to_t(self, sigma, quantize=None):
-        # quantize = self.quantize if quantize is None else quantize
-        quantize = False
-        dists = torch.abs(sigma - self.DSsigmas[:, None])
-        if quantize:
-            return torch.argmin(dists, dim=0).view(sigma.shape)
-        low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
-        low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx]
-        w = (low - sigma) / (low - high)
-        w = w.clamp(0, 1)
-        t = (1 - w) * low_idx + w * high_idx
-        return t.view(sigma.shape)
-
-    def prepare_input(self, latent_in, t, batch_size):
-        sigma = t.reshape(1)  # A# potential bug: doesn't work on samples > 1
-
-        sigma_in = torch.cat([sigma] * 2 * batch_size)
-        # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas)
-        # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in)
-        c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)]
-
-        sigma_in = self.DSsigma_to_t(sigma_in)
-        # s_in = latent_in.new_ones([latent_in.shape[0]])
-        # sigma_in = sigma_in * s_in
-
-        return c_out, c_in, sigma_in
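
The removed get_ancestral_step survives inline in the new scheduler's step below: each ancestral step moves deterministically down to sigma_down and then adds fresh noise of scale sigma_up, with the pair chosen so the combined variance lands exactly on the next noise level, i.e. sigma_down**2 + sigma_up**2 == sigma_to**2. A quick sanity check of that identity:

    import math

    def get_ancestral_step(sigma_from, sigma_to):
        # k-diffusion's split of one ancestral step into a deterministic part and injected noise
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
        return sigma_down, sigma_up

    sigma_down, sigma_up = get_ancestral_step(14.6, 9.7)  # arbitrary example levels
    assert math.isclose(sigma_down**2 + sigma_up**2, 9.7**2)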
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
new file mode 100644
index 0000000..3a2de68
--- /dev/null
+++ b/schedulers/scheduling_euler_ancestral_discrete.py
@@ -0,0 +1,192 @@
+# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Ancestral sampling with Euler method steps.
+    for discrete beta schedules. Based on the original k-diffusion implementation by
+    Katherine Crowson:
+    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.00085,  # sensible defaults
+        beta_end: float = 0.012,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+    ):
+        if trained_betas is not None:
+            self.betas = torch.from_numpy(trained_betas)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        else:
+            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+        self.sigmas = torch.from_numpy(sigmas)
+
+        self.init_noise_sigma = None
+
+        # setable values
+        self.num_inference_steps = None
+        timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps)
+        self.derivatives = []
+        self.is_scale_input_called = False
+
+    def scale_model_input(
+        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor]
+    ) -> torch.FloatTensor:
+        """
+        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        sigma = self.sigmas[step_index]
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+        return sample
+
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
+
+        low_idx = np.floor(self.timesteps).astype(int)
+        high_idx = np.ceil(self.timesteps).astype(int)
+        frac = np.mod(self.timesteps, 1.0)
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+        self.sigmas = torch.from_numpy(sigmas)
+        self.timesteps = torch.from_numpy(self.timesteps)
+        self.init_noise_sigma = self.sigmas[0]
+        self.derivatives = []
+
+    def step(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        timestep: Union[float, torch.FloatTensor],
+        step_index: Union[int, torch.IntTensor],
+        sample: Union[torch.FloatTensor, np.ndarray],
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+
+        """
+        sigma = self.sigmas[step_index]
+
+        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+        pred_original_sample = sample - sigma * model_output
+        sigma_from = self.sigmas[step_index]
+        sigma_to = self.sigmas[step_index + 1]
+        sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
+        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
+        # 2. Convert to an ODE derivative
+        derivative = (sample - pred_original_sample) / sigma
+        self.derivatives.append(derivative)
+
+        dt = sigma_down - sigma
+
+        prev_sample = sample + derivative * dt
+
+        prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return SchedulerOutput(prev_sample=prev_sample)
+
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        # Make sure sigmas and timesteps have the same device and dtype as original_samples
+        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        self.timesteps = self.timesteps.to(original_samples.device)
+        sigma = self.sigmas[timesteps].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        noisy_samples = original_samples + noise * sigma
+        return noisy_samples
+
+    def __len__(self):
+        return self.config.num_train_timesteps
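
A consequence of the k-diffusion convention adopted here: latents are kept at noise scale sigma rather than unit variance, so set_timesteps publishes init_noise_sigma = sigmas[0] for the caller to scale its initial Gaussian noise by, and scale_model_input divides every UNet input by (sigma**2 + 1) ** 0.5 before prediction. A short sketch of the sigma bookkeeping, assuming the repository layout above and the Stable Diffusion betas used throughout this commit:

    import torch
    from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler

    scheduler = EulerAncestralDiscreteScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )
    scheduler.set_timesteps(50)

    print(scheduler.sigmas[0])   # largest noise level, roughly 14.6 for these betas
    print(scheduler.sigmas[-1])  # 0.0, appended so step() can always index step_index + 1

    noise = torch.randn(1, 4, 64, 64)             # unit-variance Gaussian
    latents = noise * scheduler.init_noise_sigma  # scaled to the starting sigma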
diff --git a/textual_inversion.py b/textual_inversion.py
index bcdfd3a..dd7c3bd 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -22,7 +22,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify

-from schedulers.scheduling_euler_a import EulerAScheduler
+from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from models.clip.prompt import PromptProcessor
@@ -398,7 +398,7 @@ class Checkpointer:
         samples_path = Path(self.output_dir).joinpath("samples")

         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
-        scheduler = EulerAScheduler(
+        scheduler = EulerAncestralDiscreteScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )

@@ -639,7 +639,7 @@ def main():
     batched_data = [missing_data[i:i+args.sample_batch_size]
                     for i in range(0, len(missing_data), args.sample_batch_size)]

-    scheduler = EulerAScheduler(
+    scheduler = EulerAncestralDiscreteScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
     )
