-rw-r--r--  environment.yaml                          |  18
-rw-r--r--  infer.py                                  |   7
-rw-r--r--  schedulers/scheduling_unipc_multistep.py  | 615
-rw-r--r--  train_dreambooth.py                       |   4
-rw-r--r--  train_lora.py                             |   4
-rw-r--r--  train_ti.py                               |   6
-rw-r--r--  training/functional.py                    |   3
-rw-r--r--  training/strategy/dreambooth.py           |   8
-rw-r--r--  training/strategy/lora.py                 |   2
-rw-r--r--  training/strategy/ti.py                   |   4
10 files changed, 27 insertions(+), 644 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index f5632bf..8010c09 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,28 +1,24 @@
 name: ldd
 channels:
-  - pytorch
+  - pytorch-nightly
   - nvidia
   - xformers/label/dev
   - defaults
 dependencies:
-  - cudatoolkit=11.3
-  - libcufile=1.4.0.31
+  - cudatoolkit=11.7
   - matplotlib=3.6.2
   - numpy=1.23.4
   - pip=22.3.1
   - python=3.10.8
-  - pytorch=1.13.1=*cuda*
-  - torchvision=0.14.1
+  - pytorch=2.0.0.dev20230216=*cuda*
+  - torchvision=0.15.0.dev20230216
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
-    - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion
-    - accelerate==0.15.0
+    - accelerate==0.16.0
     - bitsandbytes==0.37.0
     - python-slugify>=6.1.2
-    - safetensors==0.2.7
+    - safetensors==0.2.8
     - setuptools==65.6.3
     - test-tube>=0.7.5
-    - transformers==4.25.1
+    - transformers==4.26.1
-    - triton==2.0.0.dev20221202
-    - xformers==0.0.17.dev443
diff --git a/infer.py b/infer.py
index 329c60b..13219f8 100644
--- a/infer.py
+++ b/infer.py
@@ -21,7 +21,8 @@ from diffusers import (
     LMSDiscreteScheduler,
     EulerAncestralDiscreteScheduler,
     KDPM2DiscreteScheduler,
-    KDPM2AncestralDiscreteScheduler
+    KDPM2AncestralDiscreteScheduler,
+    UniPCMultistepScheduler
 )
 from transformers import CLIPTextModel
 
@@ -29,7 +30,6 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 from util import load_config, load_embeddings_from_dir
 
 
@@ -245,7 +245,8 @@ def create_pipeline(model, dtype):
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
-    pipeline.enable_xformers_memory_efficient_attention()
+    # pipeline.enable_xformers_memory_efficient_attention()
+    pipeline.unet = torch.compile(pipeline.unet)
     pipeline.enable_vae_slicing()
     pipeline.to("cuda")
 
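With UniPCMultistepScheduler imported from diffusers and xformers attention replaced by torch.compile, the pipeline setup in infer.py reduces to roughly the following sketch. The generic StableDiffusionPipeline and model id here are stand-ins for the repo's VlpnStableDiffusion wiring:

    # Sketch of the new pipeline setup (generic diffusers pipeline, assumed model id).
    import torch
    from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    # Use the UniPC sampler shipped with diffusers instead of the removed local copy.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    # PyTorch 2.0: compile the UNet rather than enabling xformers attention.
    pipe.unet = torch.compile(pipe.unet)
    pipe.enable_vae_slicing()
    pipe.to("cuda")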
diff --git a/schedulers/scheduling_unipc_multistep.py b/schedulers/scheduling_unipc_multistep.py
deleted file mode 100644
index ff5db24..0000000
--- a/schedulers/scheduling_unipc_multistep.py
+++ /dev/null
@@ -1,615 +0,0 @@
1# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
16
17import math
18from typing import List, Optional, Union
19
20import numpy as np
21import torch
22
23from diffusers.configuration_utils import ConfigMixin, register_to_config
24from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
25
26
27def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
28 """
29 Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
30 (1-beta) over time from t = [0,1].
31
32 Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
33 to that part of the diffusion process.
34
35
36 Args:
37 num_diffusion_timesteps (`int`): the number of betas to produce.
38 max_beta (`float`): the maximum beta to use; use values lower than 1 to
39 prevent singularities.
40
41 Returns:
42 betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
43 """
44
45 def alpha_bar(time_step):
46 return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
47
48 betas = []
49 for i in range(num_diffusion_timesteps):
50 t1 = i / num_diffusion_timesteps
51 t2 = (i + 1) / num_diffusion_timesteps
52 betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
53 return torch.tensor(betas, dtype=torch.float32)
54
55
56class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
57 """
58 UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of
59 a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
60 UniPC is designed to be model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional
61 sampling. It can also be applied to both noise prediction and data prediction models. The corrector
62 UniC can also be applied after any off-the-shelf solvers to increase the order of accuracy.
63
64 For more details, see the original paper: https://arxiv.org/abs/2302.04867
65
66 Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We
67 recommend using `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
68
69 We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
70 diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
71 thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
72 stable-diffusion).
73
74 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
75 function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
76 [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
77 [`~SchedulerMixin.from_pretrained`] functions.
78
79 Args:
80 num_train_timesteps (`int`): number of diffusion steps used to train the model.
81 beta_start (`float`): the starting `beta` value of inference.
82 beta_end (`float`): the final `beta` value.
83 beta_schedule (`str`):
84 the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
85 `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
86 trained_betas (`np.ndarray`, optional):
87 option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
88 solver_order (`int`, default `2`):
89 the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of
90 accuracy is `solver_order + 1` due to the UniC. We recommend using `solver_order=2` for guided
91 sampling, and `solver_order=3` for unconditional sampling.
92 prediction_type (`str`, default `epsilon`, optional):
93 prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
94 process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
95 https://imagen.research.google/video/paper.pdf)
96 thresholding (`bool`, default `False`):
97 whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
98 For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
99 use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
100 models (such as stable-diffusion).
101 dynamic_thresholding_ratio (`float`, default `0.995`):
102 the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
103 (https://arxiv.org/abs/2205.11487).
104 sample_max_value (`float`, default `1.0`):
105 the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
106 `predict_x0=True`.
107 predict_x0 (`bool`, default `True`):
108 whether to use the updating algorithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details
109 solver_type (`str`, default `bh1`):
110 the solver type of UniPC. We recommend using `bh1` for unconditional sampling when steps < 10, and
111 `bh2` otherwise.
112 lower_order_final (`bool`, default `True`):
113 whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
114 find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
115 disable_corrector (`list`, default `[]`):
116 decide at which steps to disable the corrector. For a large guidance scale, the misalignment between
117 `epsilon_theta(x_t, c)` and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be
118 mitigated by disabling the corrector at the first few steps (e.g., disable_corrector=[0]).
119 solver_p (`SchedulerMixin`):
120 can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
121 """
122
123 _compatibles = [e.name for e in KarrasDiffusionSchedulers]
124 order = 1
125
126 @register_to_config
127 def __init__(
128 self,
129 num_train_timesteps: int = 1000,
130 beta_start: float = 0.0001,
131 beta_end: float = 0.02,
132 beta_schedule: str = "linear",
133 trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
134 solver_order: int = 2,
135 prediction_type: str = "epsilon",
136 thresholding: bool = False,
137 dynamic_thresholding_ratio: float = 0.995,
138 sample_max_value: float = 1.0,
139 predict_x0: bool = True,
140 solver_type: str = "bh1",
141 lower_order_final: bool = True,
142 disable_corrector: List[int] = [],
143 solver_p: SchedulerMixin = None,
144 ):
145 if trained_betas is not None:
146 self.betas = torch.tensor(trained_betas, dtype=torch.float32)
147 elif beta_schedule == "linear":
148 self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
149 elif beta_schedule == "scaled_linear":
150 # this schedule is very specific to the latent diffusion model.
151 self.betas = (
152 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
153 )
154 elif beta_schedule == "squaredcos_cap_v2":
155 # Glide cosine schedule
156 self.betas = betas_for_alpha_bar(num_train_timesteps)
157 else:
158 raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
159
160 self.alphas = 1.0 - self.betas
161 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
162 # Currently we only support VP-type noise schedule
163 self.alpha_t = torch.sqrt(self.alphas_cumprod)
164 self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
165 self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
166
167 # standard deviation of the initial noise distribution
168 self.init_noise_sigma = 1.0
169
170 if solver_type not in ["bh1", "bh2"]:
171 raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
172
173 self.predict_x0 = predict_x0
174 # setable values
175 self.num_inference_steps = None
176 timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
177 self.timesteps = torch.from_numpy(timesteps)
178 self.model_outputs = [None] * solver_order
179 self.timestep_list = [None] * solver_order
180 self.lower_order_nums = 0
181 self.disable_corrector = disable_corrector
182 self.solver_p = solver_p
183
184 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
185 """
186 Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
187
188 Args:
189 num_inference_steps (`int`):
190 the number of diffusion steps used when generating samples with a pre-trained model.
191 device (`str` or `torch.device`, optional):
192 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
193 """
194 self.num_inference_steps = num_inference_steps
195 timesteps = (
196 np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
197 .round()[::-1][:-1]
198 .copy()
199 .astype(np.int64)
200 )
201 self.timesteps = torch.from_numpy(timesteps).to(device)
202 self.model_outputs = [
203 None,
204 ] * self.config.solver_order
205 self.lower_order_nums = 0
206 if self.solver_p:
207 self.solver_p.set_timesteps(num_inference_steps, device=device)
208
209 def convert_model_output(
210 self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
211 ):
212 r"""
213 Convert the model output to the corresponding type that the algorithm PC needs.
214
215 Args:
216 model_output (`torch.FloatTensor`): direct output from learned diffusion model.
217 timestep (`int`): current discrete timestep in the diffusion chain.
218 sample (`torch.FloatTensor`):
219 current instance of sample being created by diffusion process.
220
221 Returns:
222 `torch.FloatTensor`: the converted model output.
223 """
224 if self.predict_x0:
225 if self.config.prediction_type == "epsilon":
226 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
227 x0_pred = (sample - sigma_t * model_output) / alpha_t
228 elif self.config.prediction_type == "sample":
229 x0_pred = model_output
230 elif self.config.prediction_type == "v_prediction":
231 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
232 x0_pred = alpha_t * sample - sigma_t * model_output
233 else:
234 raise ValueError(
235 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
236 " `v_prediction` for the DPMSolverMultistepScheduler."
237 )
238
239 if self.config.thresholding:
240 # Dynamic thresholding in https://arxiv.org/abs/2205.11487
241 orig_dtype = x0_pred.dtype
242 if orig_dtype not in [torch.float, torch.double]:
243 x0_pred = x0_pred.float()
244 dynamic_max_val = torch.quantile(
245 torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
246 )
247 dynamic_max_val = torch.maximum(
248 dynamic_max_val,
249 self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
250 )[(...,) + (None,) * (x0_pred.ndim - 1)]
251 x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
252 x0_pred = x0_pred.type(orig_dtype)
253 return x0_pred
254 else:
255 if self.config.prediction_type == "epsilon":
256 return model_output
257 elif self.config.prediction_type == "sample":
258 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
259 epsilon = (sample - alpha_t * model_output) / sigma_t
260 return epsilon
261 elif self.config.prediction_type == "v_prediction":
262 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
263 epsilon = alpha_t * model_output + sigma_t * sample
264 return epsilon
265 else:
266 raise ValueError(
267 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
268 " `v_prediction` for the DPMSolverMultistepScheduler."
269 )
270
271 def multistep_uni_p_bh_update(
272 self,
273 model_output: torch.FloatTensor,
274 prev_timestep: int,
275 sample: torch.FloatTensor,
276 order: int,
277 ):
278 """
279 One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
280
281 Args:
282 model_output (`torch.FloatTensor`):
283 direct outputs from learned diffusion model at the current timestep.
284 prev_timestep (`int`): previous discrete timestep in the diffusion chain.
285 sample (`torch.FloatTensor`):
286 current instance of sample being created by diffusion process.
287 order (`int`): the order of UniP at this step, also the p in UniPC-p.
288
289 Returns:
290 `torch.FloatTensor`: the sample tensor at the previous timestep.
291 """
292 timestep_list = self.timestep_list
293 model_output_list = self.model_outputs
294
295 s0, t = self.timestep_list[-1], prev_timestep
296 m0 = model_output_list[-1]
297 x = sample
298
299 if self.solver_p:
300 x_t = self.solver_p.step(model_output, s0, x).prev_sample
301 return x_t
302
303 lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
304 alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
305 sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
306
307 h = lambda_t - lambda_s0
308 device = sample.device
309
310 rks = []
311 D1s = []
312 for i in range(1, order):
313 si = timestep_list[-(i + 1)]
314 mi = model_output_list[-(i + 1)]
315 lambda_si = self.lambda_t[si]
316 rk = ((lambda_si - lambda_s0) / h)
317 rks.append(rk)
318 D1s.append((mi - m0) / rk)
319
320 rks.append(1.)
321 rks = torch.tensor(rks, device=device)
322
323 R = []
324 b = []
325
326 hh = -h if self.predict_x0 else h
327 h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
328 h_phi_k = h_phi_1 / hh - 1
329
330 factorial_i = 1
331
332 if self.config.solver_type == 'bh1':
333 B_h = hh
334 elif self.config.solver_type == 'bh2':
335 B_h = torch.expm1(hh)
336 else:
337 raise NotImplementedError()
338
339 for i in range(1, order + 1):
340 R.append(torch.pow(rks, i - 1))
341 b.append(h_phi_k * factorial_i / B_h)
342 factorial_i *= (i + 1)
343 h_phi_k = h_phi_k / hh - 1 / factorial_i
344
345 R = torch.stack(R)
346 b = torch.tensor(b, device=device)
347
348 if len(D1s) > 0:
349 D1s = torch.stack(D1s, dim=1) # (B, K)
350 # for order 2, we use a simplified version
351 if order == 2:
352 rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
353 else:
354 rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
355 else:
356 D1s = None
357
358 if self.predict_x0:
359 x_t_ = (
360 sigma_t / sigma_s0 * x
361 - alpha_t * h_phi_1 * m0
362 )
363 if D1s is not None:
364 pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
365 else:
366 pred_res = 0
367 x_t = x_t_ - alpha_t * B_h * pred_res
368 else:
369 x_t_ = (
370 alpha_t / alpha_s0 * x
371 - sigma_t * h_phi_1 * m0
372 )
373 if D1s is not None:
374 pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
375 else:
376 pred_res = 0
377 x_t = x_t_ - sigma_t * B_h * pred_res
378
379 x_t = x_t.to(x.dtype)
380 return x_t
381
382 def multistep_uni_c_bh_update(
383 self,
384 this_model_output: torch.FloatTensor,
385 this_timestep: int,
386 last_sample: torch.FloatTensor,
387 this_sample: torch.FloatTensor,
388 order: int,
389 ):
390 """
391 One step for the UniC (B(h) version).
392
393 Args:
394 this_model_output (`torch.FloatTensor`): the model outputs at `x_t`
395 this_timestep (`int`): the current timestep `t`
396 last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}`
397 this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}`
398 order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy
399 should be order + 1
400
401 Returns:
402 `torch.FloatTensor`: the corrected sample tensor at the current timestep.
403 """
404 timestep_list = self.timestep_list
405 model_output_list = self.model_outputs
406
407 s0, t = timestep_list[-1], this_timestep
408 m0 = model_output_list[-1]
409 x = last_sample
410 x_t = this_sample
411 model_t = this_model_output
412
413 lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
414 alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
415 sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
416
417 h = lambda_t - lambda_s0
418 device = this_sample.device
419
420 rks = []
421 D1s = []
422 for i in range(1, order):
423 si = timestep_list[-(i + 1)]
424 mi = model_output_list[-(i + 1)]
425 lambda_si = self.lambda_t[si]
426 rk = ((lambda_si - lambda_s0) / h)
427 rks.append(rk)
428 D1s.append((mi - m0) / rk)
429
430 rks.append(1.)
431 rks = torch.tensor(rks, device=device)
432
433 R = []
434 b = []
435
436 hh = -h if self.predict_x0 else h
437 h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
438 h_phi_k = h_phi_1 / hh - 1
439
440 factorial_i = 1
441
442 if self.config.solver_type == 'bh1':
443 B_h = hh
444 elif self.config.solver_type == 'bh2':
445 B_h = torch.expm1(hh)
446 else:
447 raise NotImplementedError()
448
449 for i in range(1, order + 1):
450 R.append(torch.pow(rks, i - 1))
451 b.append(h_phi_k * factorial_i / B_h)
452 factorial_i *= (i + 1)
453 h_phi_k = h_phi_k / hh - 1 / factorial_i
454
455 R = torch.stack(R)
456 b = torch.tensor(b, device=device)
457
458 if len(D1s) > 0:
459 D1s = torch.stack(D1s, dim=1)
460 else:
461 D1s = None
462
463 # for order 1, we use a simplified version
464 if order == 1:
465 rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
466 else:
467 rhos_c = torch.linalg.solve(R, b)
468
469 if self.predict_x0:
470 x_t_ = (
471 sigma_t / sigma_s0 * x
472 - alpha_t * h_phi_1 * m0
473 )
474 if D1s is not None:
475 corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
476 else:
477 corr_res = 0
478 D1_t = (model_t - m0)
479 x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
480 else:
481 x_t_ = (
482 alpha_t / alpha_s0 * x
483 - sigma_t * h_phi_1 * m0
484 )
485 if D1s is not None:
486 corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
487 else:
488 corr_res = 0
489 D1_t = (model_t - m0)
490 x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
491 x_t = x_t.to(x.dtype)
492 return x_t
493
494 def step(
495 self,
496 model_output: torch.FloatTensor,
497 timestep: int,
498 sample: torch.FloatTensor,
499 return_dict: bool = True,
500 ):
501 # -> Union[SchedulerOutput, Tuple]:
502 """
503 Step function propagating the sample with the multistep UniPC.
504
505 Args:
506 model_output (`torch.FloatTensor`): direct output from learned diffusion model.
507 timestep (`int`): current discrete timestep in the diffusion chain.
508 sample (`torch.FloatTensor`):
509 current instance of sample being created by diffusion process.
510 return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
511
512 Returns:
513 [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
514 True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
515
516 """
517
518 if self.num_inference_steps is None:
519 raise ValueError(
520 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
521 )
522
523 if isinstance(timestep, torch.Tensor):
524 timestep = timestep.to(self.timesteps.device)
525 step_index = (self.timesteps == timestep).nonzero()
526 if len(step_index) == 0:
527 step_index = len(self.timesteps) - 1
528 else:
529 step_index = step_index.item()
530
531 use_corrector = step_index > 0 and step_index - 1 not in self.disable_corrector # step_index not in self.disable_corrector
532
533 model_output_convert = self.convert_model_output(model_output, timestep, sample)
534 if use_corrector:
535 sample = self.multistep_uni_c_bh_update(
536 this_model_output=model_output_convert,
537 this_timestep=timestep,
538 last_sample=self.last_sample,
539 this_sample=sample,
540 order=self.this_order,
541 )
542
543 # now prepare to run the predictor
544 prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
545
546 for i in range(self.config.solver_order - 1):
547 self.model_outputs[i] = self.model_outputs[i + 1]
548 self.timestep_list[i] = self.timestep_list[i + 1]
549
550 self.model_outputs[-1] = model_output_convert
551 self.timestep_list[-1] = timestep
552
553 if self.config.lower_order_final:
554 this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
555 else:
556 this_order = self.config.solver_order
557
558 self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
559 assert self.this_order > 0
560
561 self.last_sample = sample
562 prev_sample = self.multistep_uni_p_bh_update(
563 model_output=model_output, # pass the original non-converted model output, in case solver-p is used
564 prev_timestep=prev_timestep,
565 sample=sample,
566 order=self.this_order,
567 )
568
569 if self.lower_order_nums < self.config.solver_order:
570 self.lower_order_nums += 1
571
572 if not return_dict:
573 return (prev_sample,)
574
575 return SchedulerOutput(prev_sample=prev_sample)
576
577 def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs): # -> torch.FloatTensor:
578 """
579 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
580 current timestep.
581
582 Args:
583 sample (`torch.FloatTensor`): input sample
584
585 Returns:
586 `torch.FloatTensor`: scaled input sample
587 """
588 return sample
589
590 def add_noise(
591 self,
592 original_samples: torch.FloatTensor,
593 noise: torch.FloatTensor,
594 timesteps: torch.IntTensor,
595 ):
596 # -> torch.FloatTensor:
597 # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
598 self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
599 timesteps = timesteps.to(original_samples.device)
600
601 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
602 sqrt_alpha_prod = sqrt_alpha_prod.flatten()
603 while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
604 sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
605
606 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
607 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
608 while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
609 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
610
611 noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
612 return noisy_samples
613
614 def __len__(self):
615 return self.config.num_train_timesteps
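The 615 lines removed above are the vendored UniPC implementation; its docstrings describe the intended API (set_timesteps, then alternating UniC correction and UniP prediction inside step). For reference, a minimal hand-rolled sampling loop against the equivalent scheduler now provided by diffusers might look like this sketch, with the unet and latent shape assumed and no classifier-free guidance:

    # Sketch of a bare denoising loop with the UniPCMultistepScheduler API above
    # (`unet` is an assumed noise-prediction model; latent shape is illustrative).
    import torch
    from diffusers import UniPCMultistepScheduler

    scheduler = UniPCMultistepScheduler(solver_order=2)  # order 2 is suggested for guided sampling
    scheduler.set_timesteps(25, device="cuda")

    latents = torch.randn(1, 4, 64, 64, device="cuda") * scheduler.init_noise_sigma
    for t in scheduler.timesteps:
        latent_input = scheduler.scale_model_input(latents, t)  # no-op for UniPC
        noise_pred = unet(latent_input, t).sample               # assumed model call
        latents = scheduler.step(noise_pred, t, latents).prev_sample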
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a7911c..85b756c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -464,8 +464,8 @@ def main():
     tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
+    # vae.set_use_memory_efficient_attention_xformers(True)
+    # unet.enable_xformers_memory_efficient_attention()
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
diff --git a/train_lora.py b/train_lora.py
index 330bcd6..8a06ae8 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -421,8 +421,8 @@ def main():
         args.pretrained_model_name_or_path)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
+    # vae.set_use_memory_efficient_attention_xformers(True)
+    # unet.enable_xformers_memory_efficient_attention()
 
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
diff --git a/train_ti.py b/train_ti.py
index d1defb3..7d10317 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -538,8 +538,10 @@ def main():
     tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
+    # vae.set_use_memory_efficient_attention_xformers(True)
+    # unet.enable_xformers_memory_efficient_attention()
+
+    # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
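The same two attention lines are commented out in train_dreambooth.py, train_lora.py, and train_ti.py above, matching the removal of the xformers pin from environment.yaml. If xformers is only sometimes available, a guarded helper along these lines (hypothetical, not in the repo) would keep both paths working:

    def enable_memory_efficient_attention(unet, vae):
        """Hypothetical helper: prefer xformers when installed, otherwise rely on
        the attention kernels PyTorch 2.0 provides via scaled_dot_product_attention."""
        try:
            import xformers  # noqa: F401
        except ImportError:
            return  # fall back to stock attention under torch 2.0
        unet.enable_xformers_memory_efficient_attention()
        vae.set_use_memory_efficient_attention_xformers(True)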
diff --git a/training/functional.py b/training/functional.py
index 78a2b10..41794ea 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, UniPCMultistepScheduler
 
 from tqdm.auto import tqdm
 from PIL import Image
@@ -22,7 +22,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 from training.util import AverageMeter
 
 
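training/functional.py likewise takes UniPCMultistepScheduler from diffusers now. A sketch of deriving the sampling scheduler from the training noise scheduler's config (the pretrained_model_path name is assumed):

    # Sketch: reuse the DDPM training schedule's config to build the UniPC
    # sampler used for validation images.
    from diffusers import DDPMScheduler, UniPCMultistepScheduler

    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    sample_scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)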
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 8aaed3a..d697554 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -144,8 +144,8 @@ def dreambooth_strategy_callbacks(
 
         print("Saving model...")
 
-        unet_ = accelerator.unwrap_model(unet)
-        text_encoder_ = accelerator.unwrap_model(text_encoder)
+        unet_ = accelerator.unwrap_model(unet, False)
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
 
         with ema_context():
             pipeline = VlpnStableDiffusion(
@@ -167,8 +167,8 @@ def dreambooth_strategy_callbacks(
     @torch.no_grad()
     def on_sample(step):
         with ema_context():
-            unet_ = accelerator.unwrap_model(unet)
-            text_encoder_ = accelerator.unwrap_model(text_encoder)
+            unet_ = accelerator.unwrap_model(unet, False)
+            text_encoder_ = accelerator.unwrap_model(text_encoder, False)
 
             orig_unet_dtype = unet_.dtype
             orig_text_encoder_dtype = text_encoder_.dtype
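In the strategy modules (here and in lora.py / ti.py below), unwrap_model now receives a second positional argument, in line with the accelerate==0.16.0 bump in environment.yaml. Assuming accelerate's Accelerator.unwrap_model(model, keep_fp32_wrapper=...) signature, the keyword form of the same call would be:

    # Equivalent keyword form, assuming the keep_fp32_wrapper parameter:
    # strip the mixed-precision forward wrapper before saving/sampling.
    unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
    text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)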
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 4dd1100..ccec215 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -90,7 +90,7 @@ def lora_strategy_callbacks(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        unet_ = accelerator.unwrap_model(unet)
+        unet_ = accelerator.unwrap_model(unet, False)
         unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
         del unet_
 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 0de3cb0..66d3129 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -144,8 +144,8 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_sample(step):
         with ema_context():
-            unet_ = accelerator.unwrap_model(unet)
-            text_encoder_ = accelerator.unwrap_model(text_encoder)
+            unet_ = accelerator.unwrap_model(unet, False)
+            text_encoder_ = accelerator.unwrap_model(text_encoder, False)
 
             orig_unet_dtype = unet_.dtype
             orig_text_encoder_dtype = text_encoder_.dtype