-rw-r--r--   environment.yaml                            1
-rw-r--r--   infer.py                                    7
-rw-r--r--   schedulers/scheduling_deis_multistep.py   500
-rw-r--r--   train_dreambooth.py                        38
-rw-r--r--   train_lora.py                              38
-rw-r--r--   train_ti.py                                38
-rw-r--r--   training/functional.py                      9
7 files changed, 592 insertions(+), 39 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 325644f..57624a3 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,6 +18,7 @@ dependencies:
 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
 - accelerate==0.16.0
 - bitsandbytes==0.37.0
+- lion-pytorch==0.0.6
 - python-slugify>=6.1.2
 - safetensors==0.2.8
 - setuptools==65.6.3
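
For reference, the pinned Lion dependency can also be added to an existing environment directly, equivalent to the environment.yaml change above:

    pip install lion-pytorch==0.0.6
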
diff --git a/infer.py b/infer.py
index 51cf3a7..8910e68 100644
--- a/infer.py
+++ b/infer.py
@@ -23,6 +23,7 @@ from diffusers import (
     EulerAncestralDiscreteScheduler,
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
+    DEISMultistepScheduler,
     UniPCMultistepScheduler
 )
 from transformers import CLIPTextModel
@@ -126,13 +127,13 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis", "unipc"],
     )
     parser.add_argument(
         "--subscheduler",
         type=str,
         default=None,
-        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis"],
     )
     parser.add_argument(
         "--template",
@@ -252,6 +253,8 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
         return KDPM2DiscreteScheduler.from_config(config)
     elif scheduler == "kdpm2_a":
         return KDPM2AncestralDiscreteScheduler.from_config(config)
+    elif scheduler == "deis":
+        return DEISMultistepScheduler.from_config(config)
     elif scheduler == "unipc":
         if subscheduler is None:
             return UniPCMultistepScheduler.from_config(config)
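
With this wired up, the sampler is selected at the command line via `--scheduler deis`. For context, a minimal sketch of what the new branch does, using only the diffusers API imported above; `pipeline` is a placeholder for whatever pipeline infer.py constructs:

    from diffusers import DEISMultistepScheduler

    # from_config() reuses the beta schedule and timestep settings stored in
    # the model's scheduler config, so only the ODE solver changes.
    pipeline.scheduler = DEISMultistepScheduler.from_config(pipeline.scheduler.config)
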
diff --git a/schedulers/scheduling_deis_multistep.py b/schedulers/scheduling_deis_multistep.py
new file mode 100644
index 0000000..ea1281e
--- /dev/null
+++ b/schedulers/scheduling_deis_multistep.py
@@ -0,0 +1,500 @@
+# Copyright 2022 FLAIR Lab and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info
+# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float32)
+
+
+class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
+    """
+    DEIS (https://arxiv.org/abs/2204.13902) is a fast high-order solver for diffusion ODEs. We slightly modify the
+    polynomial fitting formula in log-rho space instead of the original linear t space in the DEIS paper. The
+    modification enjoys closed-form coefficients for the exponential multistep update instead of relying on a
+    numerical solver. More variants of DEIS can be found in https://github.com/qsh-zh/deis.
+
+    Currently, we support the log-rho multistep DEIS. We recommend using `solver_order=2` or `3`, while
+    `solver_order=1` reduces to DDIM.
+
+    We also support the "dynamic thresholding" method from Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
+    diffusion models, you can set `thresholding=True` to use dynamic thresholding.
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        solver_order (`int`, default `2`):
+            the order of DEIS; can be `1`, `2`, or `3`. We recommend `solver_order=2` for guided sampling, and
+            `solver_order=3` for unconditional sampling.
+        prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon) or the data / `x0`. One of `epsilon`, `sample`,
+            or `v-prediction`.
+        thresholding (`bool`, default `False`):
+            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+            Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+            stable-diffusion).
+        dynamic_thresholding_ratio (`float`, default `0.995`):
+            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+            (https://arxiv.org/abs/2205.11487).
+        sample_max_value (`float`, default `1.0`):
+            the threshold value for dynamic thresholding. Only used when `thresholding=True`.
+        algorithm_type (`str`, default `deis`):
+            the algorithm type for the solver. Currently we only support the multistep DEIS; other variants will be
+            added in the future.
+        lower_order_final (`bool`, default `True`):
+            whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
+            find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10.
+
+    """
+
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+    order = 1
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        solver_order: int = 2,
+        prediction_type: str = "epsilon",
+        thresholding: bool = False,
+        dynamic_thresholding_ratio: float = 0.995,
+        sample_max_value: float = 1.0,
+        algorithm_type: str = "deis",
+        solver_type: str = "logrho",
+        lower_order_final: bool = True,
+    ):
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+        # Currently we only support VP-type noise schedule
+        self.alpha_t = torch.sqrt(self.alphas_cumprod)
+        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # settings for DEIS
+        if algorithm_type not in ["deis"]:
+            if algorithm_type in ["dpmsolver", "dpmsolver++"]:
+                algorithm_type = "deis"
+            else:
+                raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
+
+        if solver_type not in ["logrho"]:
+            if solver_type in ["midpoint", "heun"]:
+                solver_type = "logrho"
+            else:
+                raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}")
+
+        # setable values
+        self.num_inference_steps = None
+        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps)
+        self.model_outputs = [None] * solver_order
+        self.lower_order_nums = 0
+
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        """
+        self.num_inference_steps = num_inference_steps
+        timesteps = (
+            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+            .round()[::-1][:-1]
+            .copy()
+            .astype(np.int64)
+        )
+        self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.model_outputs = [
+            None,
+        ] * self.config.solver_order
+        self.lower_order_nums = 0
+
+    def convert_model_output(
+        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        """
+        Convert the model output to the corresponding type that the DEIS algorithm needs.
+
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+
+        Returns:
+            `torch.FloatTensor`: the converted model output.
+        """
+        if self.config.prediction_type == "epsilon":
+            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+            x0_pred = (sample - sigma_t * model_output) / alpha_t
+        elif self.config.prediction_type == "sample":
+            x0_pred = model_output
+        elif self.config.prediction_type == "v_prediction":
+            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+            x0_pred = alpha_t * sample - sigma_t * model_output
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+                " `v_prediction` for the DEISMultistepScheduler."
+            )
+
+        if self.config.thresholding:
+            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
+            orig_dtype = x0_pred.dtype
+            if orig_dtype not in [torch.float, torch.double]:
+                x0_pred = x0_pred.float()
+            dynamic_max_val = torch.quantile(
+                torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
+            )
+            dynamic_max_val = torch.maximum(
+                dynamic_max_val,
+                self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
+            )[(...,) + (None,) * (x0_pred.ndim - 1)]
+            x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
+            x0_pred = x0_pred.type(orig_dtype)
+
+        if self.config.algorithm_type == "deis":
+            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+            return (sample - alpha_t * x0_pred) / sigma_t
+        else:
+            raise NotImplementedError("only the log-rho multistep DEIS is supported for now")
+
+    def deis_first_order_update(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        prev_timestep: int,
+        sample: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """
+        One step for the first-order DEIS (equivalent to DDIM).
+
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+
+        Returns:
+            `torch.FloatTensor`: the sample tensor at the previous timestep.
+        """
+        lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
+        alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
+        sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep]
+        h = lambda_t - lambda_s
+        if self.config.algorithm_type == "deis":
+            x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
+        else:
+            raise NotImplementedError("only the log-rho multistep DEIS is supported for now")
+        return x_t
+
+    def multistep_deis_second_order_update(
+        self,
+        model_output_list: List[torch.FloatTensor],
+        timestep_list: List[int],
+        prev_timestep: int,
+        sample: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """
+        One step for the second-order multistep DEIS.
+
+        Args:
+            model_output_list (`List[torch.FloatTensor]`):
+                direct outputs from learned diffusion model at current and latter timesteps.
+            timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
+            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+
+        Returns:
+            `torch.FloatTensor`: the sample tensor at the previous timestep.
+        """
+        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
+        m0, m1 = model_output_list[-1], model_output_list[-2]
+        alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1]
+        sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1]
+
+        rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
+
+        if self.config.algorithm_type == "deis":
+
+            def ind_fn(t, b, c):
+                # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}]
+                return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c))
+
+            coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1)
+            coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0)
+
+            x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1)
+            return x_t
+        else:
+            raise NotImplementedError("only the log-rho multistep DEIS is supported for now")
+
+    def multistep_deis_third_order_update(
+        self,
+        model_output_list: List[torch.FloatTensor],
+        timestep_list: List[int],
+        prev_timestep: int,
+        sample: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """
+        One step for the third-order multistep DEIS.
+
+        Args:
+            model_output_list (`List[torch.FloatTensor]`):
+                direct outputs from learned diffusion model at current and latter timesteps.
+            timestep_list (`List[int]`): current and latter discrete timesteps in the diffusion chain.
+            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+
+        Returns:
+            `torch.FloatTensor`: the sample tensor at the previous timestep.
+        """
+        t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
+        m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
+        alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2]
+        sigma_t, sigma_s0, sigma_s1, sigma_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2]
+        rho_t, rho_s0, rho_s1, rho_s2 = (
+            sigma_t / alpha_t,
+            sigma_s0 / alpha_s0,
+            sigma_s1 / alpha_s1,
+            sigma_s2 / alpha_s2,
+        )
+
+        if self.config.algorithm_type == "deis":
+
+            def ind_fn(t, b, c, d):
+                # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}]
+                numerator = t * (
+                    np.log(c) * (np.log(d) - np.log(t) + 1)
+                    - np.log(d) * np.log(t)
+                    + np.log(d)
+                    + np.log(t) ** 2
+                    - 2 * np.log(t)
+                    + 2
+                )
+                denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d))
+                return numerator / denominator
+
+            coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2)
+            coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0)
+            coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1)
+
+            x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2)
+
+            return x_t
+        else:
+            raise NotImplementedError("only the log-rho multistep DEIS is supported for now")
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        sample: torch.FloatTensor,
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Step function propagating the sample with the multistep DEIS.
+
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning a tuple rather than a SchedulerOutput class
+
+        Returns:
+            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero()
+        if len(step_index) == 0:
+            step_index = len(self.timesteps) - 1
+        else:
+            step_index = step_index.item()
+        prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
+        lower_order_final = (
+            (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
+        )
+        lower_order_second = (
+            (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+        )
+
+        model_output = self.convert_model_output(model_output, timestep, sample)
+        for i in range(self.config.solver_order - 1):
+            self.model_outputs[i] = self.model_outputs[i + 1]
+        self.model_outputs[-1] = model_output
+
+        if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+            prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample)
+        elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+            timestep_list = [self.timesteps[step_index - 1], timestep]
+            prev_sample = self.multistep_deis_second_order_update(
+                self.model_outputs, timestep_list, prev_timestep, sample
+            )
+        else:
+            timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
+            prev_sample = self.multistep_deis_third_order_update(
+                self.model_outputs, timestep_list, prev_timestep, sample
+            )
+
+        if self.lower_order_nums < self.config.solver_order:
+            self.lower_order_nums += 1
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return SchedulerOutput(prev_sample=prev_sample)
+
+    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
+
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        return noisy_samples
+
+    def get_velocity(
+        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timestep have same device and dtype as sample
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        timesteps = timesteps.to(sample.device)
+
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
+
+    def __len__(self):
+        return self.config.num_train_timesteps
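
To make the control flow of the new scheduler concrete, here is a minimal standalone sampling loop built only from the methods defined above (`set_timesteps`, `scale_model_input`, `step`); the `denoiser` lambda and the latent shape are placeholders for an actual UNet:

    import torch
    from schedulers.scheduling_deis_multistep import DEISMultistepScheduler

    denoiser = lambda x, t: torch.zeros_like(x)  # stand-in for a real model call

    scheduler = DEISMultistepScheduler(solver_order=2)
    scheduler.set_timesteps(20)

    sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(sample, t)  # no-op for DEIS
        noise_pred = denoiser(model_input, t)
        sample = scheduler.step(noise_pred, t, sample).prev_sample
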
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a7911c..8f0c6ea 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -285,9 +285,10 @@ def parse_args():
         default=0.9999
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="lion",
+        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -491,15 +492,34 @@ def main():
         args.learning_rate = 1e-6
         args.lr_scheduler = "exponential_growth"
 
-    if args.use_8bit_adam:
+    if args.optimizer == 'adam8bit':
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
 
-        optimizer_class = bnb.optim.AdamW8bit
+        create_optimizer = partial(
+            bnb.optim.AdamW8bit,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+    elif args.optimizer == 'adam':
+        create_optimizer = partial(
+            torch.optim.AdamW,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
     else:
-        optimizer_class = torch.optim.AdamW
+        try:
+            from lion_pytorch import Lion
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
+
+        create_optimizer = partial(Lion, use_triton=True)
 
     trainer = partial(
         train,
@@ -540,17 +560,13 @@ def main():
     )
     datamodule.setup()
 
-    optimizer = optimizer_class(
+    optimizer = create_optimizer(
         itertools.chain(
             unet.parameters(),
             text_encoder.text_model.encoder.parameters(),
             text_encoder.text_model.final_layer_norm.parameters(),
         ),
         lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
     )
 
     lr_scheduler = get_scheduler(
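
The same optimizer-factory pattern recurs in train_lora.py and train_ti.py below: `functools.partial` binds the Adam-specific keyword arguments (betas, weight decay, eps, amsgrad) at selection time, so the shared call site only supplies parameters and a learning rate. That narrower call signature is what lets Lion, whose constructor takes neither `eps` nor `amsgrad`, reuse the same call site. A reduced sketch of the pattern (names follow the diff; the values are illustrative):

    from functools import partial

    import torch

    create_optimizer = partial(
        torch.optim.AdamW,
        betas=(0.9, 0.999),   # args.adam_beta1 / args.adam_beta2
        weight_decay=1e-2,    # args.adam_weight_decay
        eps=1e-8,             # args.adam_epsilon
    )

    # Optimizer-agnostic call site: only params and lr, as in the diff.
    model = torch.nn.Linear(4, 4)
    optimizer = create_optimizer(model.parameters(), lr=1e-6)

Note that the Lion branch passes `use_triton=True`, which appears to require the optional triton package in addition to lion-pytorch itself.
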
diff --git a/train_lora.py b/train_lora.py
index 330bcd6..368c29b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -245,9 +245,10 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="lion",
+        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -466,15 +467,34 @@ def main():
         args.learning_rate = 1e-6
         args.lr_scheduler = "exponential_growth"
 
-    if args.use_8bit_adam:
+    if args.optimizer == 'adam8bit':
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
 
-        optimizer_class = bnb.optim.AdamW8bit
+        create_optimizer = partial(
+            bnb.optim.AdamW8bit,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+    elif args.optimizer == 'adam':
+        create_optimizer = partial(
+            torch.optim.AdamW,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
     else:
-        optimizer_class = torch.optim.AdamW
+        try:
+            from lion_pytorch import Lion
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
+
+        create_optimizer = partial(Lion, use_triton=True)
 
     trainer = partial(
         train,
@@ -516,13 +536,9 @@ def main():
     )
     datamodule.setup()
 
-    optimizer = optimizer_class(
+    optimizer = create_optimizer(
         lora_layers.parameters(),
         lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
     )
 
     lr_scheduler = get_scheduler(
diff --git a/train_ti.py b/train_ti.py
index 3aa1027..507d710 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -290,9 +290,10 @@ def parse_args():
         default=0.9999
     )
     parser.add_argument(
-        "--use_8bit_adam",
-        action="store_true",
-        help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="lion",
+        help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
         "--adam_beta1",
@@ -564,15 +565,34 @@ def main():
         args.learning_rate = 1e-5
         args.lr_scheduler = "exponential_growth"
 
-    if args.use_8bit_adam:
+    if args.optimizer == 'adam8bit':
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
 
-        optimizer_class = bnb.optim.AdamW8bit
+        create_optimizer = partial(
+            bnb.optim.AdamW8bit,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+    elif args.optimizer == 'adam':
+        create_optimizer = partial(
+            torch.optim.AdamW,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
     else:
-        optimizer_class = torch.optim.AdamW
+        try:
+            from lion_pytorch import Lion
+        except ImportError:
+            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
+
+        create_optimizer = partial(Lion, use_triton=True)
 
     checkpoint_output_dir = output_dir/"checkpoints"
 
@@ -658,13 +678,9 @@ def main():
     )
     datamodule.setup()
 
-    optimizer = optimizer_class(
+    optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
     )
 
     lr_scheduler = get_scheduler(
diff --git a/training/functional.py b/training/functional.py
index 41794ea..4d0cf0e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, UniPCMultistepScheduler
+from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler
 
 from tqdm.auto import tqdm
 from PIL import Image
@@ -22,6 +22,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
+from schedulers.scheduling_deis_multistep import DEISMultistepScheduler
 from training.util import AverageMeter
 
 
@@ -78,7 +79,7 @@ def get_models(pretrained_model_name_or_path: str):
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
@@ -251,7 +252,7 @@ def add_placeholder_tokens(
 
 def loss_step(
     vae: AutoencoderKL,
-    noise_scheduler: DDPMScheduler,
+    noise_scheduler: DEISMultistepScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
@@ -551,7 +552,7 @@ def train(
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
     vae: AutoencoderKL,
-    noise_scheduler: DDPMScheduler,
+    noise_scheduler: DEISMultistepScheduler,
     dtype: torch.dtype,
     seed: int,
     project: str,
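
Swapping DDPMScheduler for DEISMultistepScheduler as `noise_scheduler` is safe for training because `loss_step` only relies on the forward-process helpers, which the new file implements with the standard formulas (`add_noise`, `get_velocity`). A minimal sketch of that usage, with shapes and names as assumptions rather than the repo's exact code:

    import torch
    from schedulers.scheduling_deis_multistep import DEISMultistepScheduler

    noise_scheduler = DEISMultistepScheduler()
    latents = torch.randn(4, 4, 64, 64)  # stand-in for VAE-encoded latents
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],)
    )

    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    # The regression target depends on the model's prediction type:
    target = noise  # for "epsilon"
    # target = noise_scheduler.get_velocity(latents, noise, timesteps)  # for "v_prediction"
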