author     Volpeon <git@volpeon.ink>  2023-02-16 09:16:05 +0100
committer  Volpeon <git@volpeon.ink>  2023-02-16 09:16:05 +0100
commit     d673760fc671d665aadae3b032f8e99f21ab986d (patch)
tree       7c14a998742b19ddecac6ee25a669892b41c305e
parent     Update (diff)
download   textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.tar.gz
           textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.tar.bz2
           textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.zip
Integrated WIP UniPC scheduler
-rw-r--r--  infer.py                                  |  15
-rw-r--r--  schedulers/scheduling_unipc_multistep.py  | 615
-rw-r--r--  train_dreambooth.py                       |   3
-rw-r--r--  train_lora.py                             |   3
-rw-r--r--  train_ti.py                               |   3
-rw-r--r--  training/functional.py                    |  30
6 files changed, 655 insertions(+), 14 deletions(-)
diff --git a/infer.py b/infer.py
index aa75ee5..329c60b 100644
--- a/infer.py
+++ b/infer.py
@@ -29,6 +29,7 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 from util import load_config, load_embeddings_from_dir
 
 
@@ -61,6 +62,7 @@ default_cmds = {
61 "batch_num": 1, 62 "batch_num": 1,
62 "steps": 30, 63 "steps": 30,
63 "guidance_scale": 7.0, 64 "guidance_scale": 7.0,
65 "sag_scale": 0.75,
64 "lora_scale": 0.5, 66 "lora_scale": 0.5,
65 "seed": None, 67 "seed": None,
66 "config": None, 68 "config": None,
@@ -122,7 +124,7 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "unipc"],
     )
     parser.add_argument(
         "--template",
@@ -175,6 +177,10 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
+        "--sag_scale",
+        type=float,
+    )
+    parser.add_argument(
         "--lora_scale",
         type=float,
     )
@@ -304,6 +310,8 @@ def generate(output_dir: Path, pipeline, args):
         pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "kdpm2_a":
         pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "unipc":
+        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
 
     for i in range(args.batch_num):
         pipeline.set_progress_bar_config(
@@ -322,10 +330,11 @@ def generate(output_dir: Path, pipeline, args):
             num_images_per_prompt=args.batch_size,
             num_inference_steps=args.steps,
             guidance_scale=args.guidance_scale,
+            sag_scale=args.sag_scale,
             generator=generator,
             image=init_image,
             strength=args.image_noise,
-            cross_attention_kwargs={"scale": args.lora_scale},
+            # cross_attention_kwargs={"scale": args.lora_scale},
         ).images
 
         for j, image in enumerate(images):
@@ -408,7 +417,7 @@ def main():
     pipeline = create_pipeline(args.model, dtype)
 
     load_embeddings(pipeline, args.ti_embeddings_dir)
-    pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
+    # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser)
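
A usage sketch, not part of the commit: the new "unipc" choice swaps the sampler exactly like the existing ones. A minimal sketch against the stock diffusers pipeline; this repo's VlpnStableDiffusion and its sag_scale argument are assumed to behave the same way but are not exercised here:

    # Swap any diffusers-style pipeline over to the WIP UniPC scheduler,
    # mirroring the new `elif args.scheduler == "unipc"` branch above.
    from diffusers import StableDiffusionPipeline
    from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    image = pipe("a watercolor fox", num_inference_steps=30, guidance_scale=7.0).images[0]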
diff --git a/schedulers/scheduling_unipc_multistep.py b/schedulers/scheduling_unipc_multistep.py
new file mode 100644
index 0000000..ff5db24
--- /dev/null
+++ b/schedulers/scheduling_unipc_multistep.py
@@ -0,0 +1,615 @@
+# Copyright 2022 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
+
+import math
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
+            prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float32)
+
+
+class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
57 """
58 UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of
59 a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.
60 UniPC is by desinged model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional
61 sampling. It can also be applied to both noise prediction model and data prediction model. The corrector
62 UniC can be also applied after any off-the-shelf solvers to increase the order of accuracy.
63
64 For more details, see the original paper: https://arxiv.org/abs/2302.04867
65
66 Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We
67 recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
68
69 We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
70 diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
71 thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
72 stable-diffusion).
73
74 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
75 function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
76 [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
77 [`~SchedulerMixin.from_pretrained`] functions.
78
79 Args:
80 num_train_timesteps (`int`): number of diffusion steps used to train the model.
81 beta_start (`float`): the starting `beta` value of inference.
82 beta_end (`float`): the final `beta` value.
83 beta_schedule (`str`):
84 the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
85 `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
86 trained_betas (`np.ndarray`, optional):
87 option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
88 solver_order (`int`, default `2`):
89 the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of
90 accuracy is `solver_order + 1` due to the UniC. We recommend to use `solver_order=2` for guided
91 sampling, and `solver_order=3` for unconditional sampling.
92 prediction_type (`str`, default `epsilon`, optional):
93 prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
94 process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
95 https://imagen.research.google/video/paper.pdf)
96 thresholding (`bool`, default `False`):
97 whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
98 For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
99 use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
100 models (such as stable-diffusion).
101 dynamic_thresholding_ratio (`float`, default `0.995`):
102 the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
103 (https://arxiv.org/abs/2205.11487).
104 sample_max_value (`float`, default `1.0`):
105 the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
106 `predict_x0=True`.
107 predict_x0 (`bool`, default `True`):
108 whether to use the updating algrithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details
109 solver_type (`str`, default `bh1`):
110 the solver type of UniPC. We recommend use `bh1` for unconditional sampling when steps < 10, and use
111 `bh2` otherwise.
112 lower_order_final (`bool`, default `True`):
113 whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
114 find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
115 disable_corrector (`list`, default `[]`):
116 decide which step to disable the corrector. For large guidance scale, the misalignment between the
117 `epsilon_theta(x_t, c)`and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be
118 mitigated by disable the corrector at the first few steps (e.g., disable_corrector=[0])
119 solver_p (`SchedulerMixin`):
120 can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
121 """
+
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+    order = 1
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        solver_order: int = 2,
+        prediction_type: str = "epsilon",
+        thresholding: bool = False,
+        dynamic_thresholding_ratio: float = 0.995,
+        sample_max_value: float = 1.0,
+        predict_x0: bool = True,
+        solver_type: str = "bh1",
+        lower_order_final: bool = True,
+        disable_corrector: List[int] = [],
+        solver_p: SchedulerMixin = None,
+    ):
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+        # Currently we only support VP-type noise schedule
+        self.alpha_t = torch.sqrt(self.alphas_cumprod)
+        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
+        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        if solver_type not in ["bh1", "bh2"]:
+            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+        self.predict_x0 = predict_x0
+        # setable values
+        self.num_inference_steps = None
+        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps)
+        self.model_outputs = [None] * solver_order
+        self.timestep_list = [None] * solver_order
+        self.lower_order_nums = 0
+        self.disable_corrector = disable_corrector
+        self.solver_p = solver_p
+
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        """
+        self.num_inference_steps = num_inference_steps
+        timesteps = (
+            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
+            .round()[::-1][:-1]
+            .copy()
+            .astype(np.int64)
+        )
+        self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.model_outputs = [
+            None,
+        ] * self.config.solver_order
+        self.lower_order_nums = 0
+        if self.solver_p:
+            self.solver_p.set_timesteps(num_inference_steps, device=device)
+
+    def convert_model_output(
+        self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
+    ):
+        r"""
+        Convert the model output to the corresponding type that the UniPC algorithm needs.
+
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+
+        Returns:
+            `torch.FloatTensor`: the converted model output.
+        """
+        if self.predict_x0:
+            if self.config.prediction_type == "epsilon":
+                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+                x0_pred = (sample - sigma_t * model_output) / alpha_t
+            elif self.config.prediction_type == "sample":
+                x0_pred = model_output
+            elif self.config.prediction_type == "v_prediction":
+                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+                x0_pred = alpha_t * sample - sigma_t * model_output
+            else:
+                raise ValueError(
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+                    " `v_prediction` for the UniPCMultistepScheduler."
+                )
+
+            if self.config.thresholding:
+                # Dynamic thresholding in https://arxiv.org/abs/2205.11487
+                orig_dtype = x0_pred.dtype
+                if orig_dtype not in [torch.float, torch.double]:
+                    x0_pred = x0_pred.float()
+                dynamic_max_val = torch.quantile(
+                    torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
+                )
+                dynamic_max_val = torch.maximum(
+                    dynamic_max_val,
+                    self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
+                )[(...,) + (None,) * (x0_pred.ndim - 1)]
+                x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
+                x0_pred = x0_pred.type(orig_dtype)
+            return x0_pred
+        else:
+            if self.config.prediction_type == "epsilon":
+                return model_output
+            elif self.config.prediction_type == "sample":
+                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+                epsilon = (sample - alpha_t * model_output) / sigma_t
+                return epsilon
+            elif self.config.prediction_type == "v_prediction":
+                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+                epsilon = alpha_t * model_output + sigma_t * sample
+                return epsilon
+            else:
+                raise ValueError(
+                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+                    " `v_prediction` for the UniPCMultistepScheduler."
+                )
+
+    def multistep_uni_p_bh_update(
+        self,
+        model_output: torch.FloatTensor,
+        prev_timestep: int,
+        sample: torch.FloatTensor,
+        order: int,
+    ):
+        """
+        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
+
+        Args:
+            model_output (`torch.FloatTensor`):
+                direct outputs from learned diffusion model at the current timestep.
+            prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+            order (`int`): the order of UniP at this step, also the p in UniPC-p.
+
+        Returns:
+            `torch.FloatTensor`: the sample tensor at the previous timestep.
+        """
+        timestep_list = self.timestep_list
+        model_output_list = self.model_outputs
+
+        s0, t = self.timestep_list[-1], prev_timestep
+        m0 = model_output_list[-1]
+        x = sample
+
+        if self.solver_p:
+            x_t = self.solver_p.step(model_output, s0, x).prev_sample
+            return x_t
+
+        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
+        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+
+        h = lambda_t - lambda_s0
+        device = sample.device
+
+        rks = []
+        D1s = []
+        for i in range(1, order):
+            si = timestep_list[-(i + 1)]
+            mi = model_output_list[-(i + 1)]
+            lambda_si = self.lambda_t[si]
+            rk = ((lambda_si - lambda_s0) / h)
+            rks.append(rk)
+            D1s.append((mi - m0) / rk)
+
+        rks.append(1.)
+        rks = torch.tensor(rks, device=device)
+
+        R = []
+        b = []
+
+        hh = -h if self.predict_x0 else h
+        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
+        h_phi_k = h_phi_1 / hh - 1
+
+        factorial_i = 1
+
+        if self.config.solver_type == 'bh1':
+            B_h = hh
+        elif self.config.solver_type == 'bh2':
+            B_h = torch.expm1(hh)
+        else:
+            raise NotImplementedError()
+
+        for i in range(1, order + 1):
+            R.append(torch.pow(rks, i - 1))
+            b.append(h_phi_k * factorial_i / B_h)
+            factorial_i *= (i + 1)
+            h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+        R = torch.stack(R)
+        b = torch.tensor(b, device=device)
+
+        if len(D1s) > 0:
+            D1s = torch.stack(D1s, dim=1)  # (B, K, C, H, W)
+            # for order 2, we use a simplified version
+            if order == 2:
+                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
+            else:
+                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+        else:
+            D1s = None
+
+        if self.predict_x0:
+            x_t_ = (
+                sigma_t / sigma_s0 * x
+                - alpha_t * h_phi_1 * m0
+            )
+            if D1s is not None:
+                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+            else:
+                pred_res = 0
+            x_t = x_t_ - alpha_t * B_h * pred_res
+        else:
+            x_t_ = (
+                alpha_t / alpha_s0 * x
+                - sigma_t * h_phi_1 * m0
+            )
+            if D1s is not None:
+                pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+            else:
+                pred_res = 0
+            x_t = x_t_ - sigma_t * B_h * pred_res
+
+        x_t = x_t.to(x.dtype)
+        return x_t
+
+    def multistep_uni_c_bh_update(
+        self,
+        this_model_output: torch.FloatTensor,
+        this_timestep: int,
+        last_sample: torch.FloatTensor,
+        this_sample: torch.FloatTensor,
+        order: int,
+    ):
+        """
+        One step for the UniC (B(h) version).
+
+        Args:
+            this_model_output (`torch.FloatTensor`): the model outputs at `x_t`
+            this_timestep (`int`): the current timestep `t`
+            last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}`
+            this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}`
+            order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy
+                should be order + 1
+
+        Returns:
+            `torch.FloatTensor`: the corrected sample tensor at the current timestep.
+        """
+        timestep_list = self.timestep_list
+        model_output_list = self.model_outputs
+
+        s0, t = timestep_list[-1], this_timestep
+        m0 = model_output_list[-1]
+        x = last_sample
+        x_t = this_sample
+        model_t = this_model_output
+
+        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
+        alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
+        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+
+        h = lambda_t - lambda_s0
+        device = this_sample.device
+
+        rks = []
+        D1s = []
+        for i in range(1, order):
+            si = timestep_list[-(i + 1)]
+            mi = model_output_list[-(i + 1)]
+            lambda_si = self.lambda_t[si]
+            rk = ((lambda_si - lambda_s0) / h)
+            rks.append(rk)
+            D1s.append((mi - m0) / rk)
+
+        rks.append(1.)
+        rks = torch.tensor(rks, device=device)
+
+        R = []
+        b = []
+
+        hh = -h if self.predict_x0 else h
+        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
+        h_phi_k = h_phi_1 / hh - 1
+
+        factorial_i = 1
+
+        if self.config.solver_type == 'bh1':
+            B_h = hh
+        elif self.config.solver_type == 'bh2':
+            B_h = torch.expm1(hh)
+        else:
+            raise NotImplementedError()
+
+        for i in range(1, order + 1):
+            R.append(torch.pow(rks, i - 1))
+            b.append(h_phi_k * factorial_i / B_h)
+            factorial_i *= (i + 1)
+            h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+        R = torch.stack(R)
+        b = torch.tensor(b, device=device)
+
+        if len(D1s) > 0:
+            D1s = torch.stack(D1s, dim=1)
+        else:
+            D1s = None
+
+        # for order 1, we use a simplified version
+        if order == 1:
+            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
+        else:
+            rhos_c = torch.linalg.solve(R, b)
+
+        if self.predict_x0:
+            x_t_ = (
+                sigma_t / sigma_s0 * x
+                - alpha_t * h_phi_1 * m0
+            )
+            if D1s is not None:
+                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+            else:
+                corr_res = 0
+            D1_t = (model_t - m0)
+            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+        else:
+            x_t_ = (
+                alpha_t / alpha_s0 * x
+                - sigma_t * h_phi_1 * m0
+            )
+            if D1s is not None:
+                corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+            else:
+                corr_res = 0
+            D1_t = (model_t - m0)
+            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+        x_t = x_t.to(x.dtype)
+        return x_t
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        sample: torch.FloatTensor,
+        return_dict: bool = True,
+    ):
+        # -> Union[SchedulerOutput, Tuple]:
+        """
+        Step function propagating the sample with the multistep UniPC.
+
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+        Returns:
+            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+        """
+
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero()
+        if len(step_index) == 0:
+            step_index = len(self.timesteps) - 1
+        else:
+            step_index = step_index.item()
+
+        use_corrector = step_index > 0 and step_index - 1 not in self.disable_corrector  # step_index not in self.disable_corrector
+
+        model_output_convert = self.convert_model_output(model_output, timestep, sample)
+        if use_corrector:
+            sample = self.multistep_uni_c_bh_update(
+                this_model_output=model_output_convert,
+                this_timestep=timestep,
+                last_sample=self.last_sample,
+                this_sample=sample,
+                order=self.this_order,
+            )
+
+        # now prepare to run the predictor
+        prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
+
+        for i in range(self.config.solver_order - 1):
+            self.model_outputs[i] = self.model_outputs[i + 1]
+            self.timestep_list[i] = self.timestep_list[i + 1]
+
+        self.model_outputs[-1] = model_output_convert
+        self.timestep_list[-1] = timestep
+
+        if self.config.lower_order_final:
+            this_order = min(self.config.solver_order, len(self.timesteps) - step_index)
+        else:
+            this_order = self.config.solver_order
+
+        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
+        assert self.this_order > 0
+
+        self.last_sample = sample
+        prev_sample = self.multistep_uni_p_bh_update(
+            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
+            prev_timestep=prev_timestep,
+            sample=sample,
+            order=self.this_order,
+        )
+
+        if self.lower_order_nums < self.config.solver_order:
+            self.lower_order_nums += 1
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return SchedulerOutput(prev_sample=prev_sample)
+
+    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs):  # -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ):
+        # -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
+
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        return noisy_samples
+
+    def __len__(self):
+        return self.config.num_train_timesteps
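
A usage sketch, not part of the commit: the predictor-corrector loop described in the class docstring can be driven by hand. A toy noise-prediction model (random tensors) stands in for a real UNet; shapes are illustrative:

    import torch
    from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

    scheduler = UniPCMultistepScheduler(solver_order=2)  # UniC lifts effective accuracy to order 3
    scheduler.set_timesteps(10)

    sample = torch.randn(1, 4, 64, 64)  # latent-shaped starting noise
    for t in scheduler.timesteps:
        model_output = torch.randn_like(sample)  # stand-in for unet(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample

On the first step there is no stored model output, so only the predictor (UniP) runs; from the second step on, each call first corrects the previous sample with UniC, then predicts the next one.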
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4c1ec31..5a7911c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -375,7 +375,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=10,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -511,6 +511,7 @@ def main():
         dtype=weight_dtype,
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
+        no_val=args.valid_set_size == 0,
     )
 
     checkpoint_output_dir = output_dir / "model"
diff --git a/train_lora.py b/train_lora.py
index a8c1cf6..330bcd6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -335,7 +335,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=10,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -487,6 +487,7 @@ def main():
         dtype=weight_dtype,
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
+        no_val=args.valid_set_size == 0,
     )
 
     checkpoint_output_dir = output_dir / "model"
diff --git a/train_ti.py b/train_ti.py
index f78c7d2..d1defb3 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -392,7 +392,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=10,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -586,6 +586,7 @@ def main():
         seed=args.seed,
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
+        no_val=args.valid_set_size == 0,
         low_freq_noise=0,
         strategy=textual_inversion_strategy,
         num_train_epochs=args.num_train_epochs,
diff --git a/training/functional.py b/training/functional.py
index e1035ce..b7ea90d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 
 from tqdm.auto import tqdm
 from PIL import Image
@@ -22,6 +22,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
+from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 from training.util import AverageMeter
 
 
@@ -79,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str):
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
     embeddings = patch_managed_embeddings(text_encoder)
@@ -93,7 +94,7 @@ def save_samples(
     text_encoder: CLIPTextModel,
     tokenizer: MultiCLIPTokenizer,
     vae: AutoencoderKL,
-    sample_scheduler: DPMSolverMultistepScheduler,
+    sample_scheduler: UniPCMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     output_dir: Path,
@@ -180,7 +181,7 @@ def generate_class_images(
     vae: AutoencoderKL,
     unet: UNet2DConditionModel,
     tokenizer: MultiCLIPTokenizer,
-    sample_scheduler: DPMSolverMultistepScheduler,
+    sample_scheduler: UniPCMultistepScheduler,
     train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
@@ -284,6 +285,7 @@ def loss_step(
         device=latents.device,
         generator=generator
     )
+
     bsz = latents.shape[0]
     # Sample a random timestep for each image
     timesteps = torch.randint(
@@ -351,6 +353,7 @@ def train_loop(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     loss_step: LossCallable,
+    no_val: bool = False,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
@@ -406,9 +409,15 @@ def train_loop(
     for epoch in range(num_epochs):
         if accelerator.is_main_process:
             if epoch % sample_frequency == 0:
+                local_progress_bar.clear()
+                global_progress_bar.clear()
+
                 on_sample(global_step + global_step_offset)
 
             if epoch % checkpoint_frequency == 0 and epoch != 0:
+                local_progress_bar.clear()
+                global_progress_bar.clear()
+
                 on_checkpoint(global_step + global_step_offset, "training")
 
         local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -463,7 +472,7 @@ def train_loop(
 
         on_after_epoch(lr_scheduler.get_last_lr()[0])
 
-        if val_dataloader is not None:
+        if val_dataloader is not None and not no_val:
             model.eval()
 
             cur_loss_val = AverageMeter()
@@ -498,11 +507,11 @@ def train_loop(
 
             accelerator.log(logs, step=global_step)
 
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
             if accelerator.is_main_process:
                 if avg_acc_val.avg.item() > best_acc_val:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                     on_checkpoint(global_step + global_step_offset, "milestone")
@@ -513,6 +522,9 @@ def train_loop(
         else:
             if accelerator.is_main_process:
                 if avg_acc.avg.item() > best_acc:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     accelerator.print(
                         f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
                     on_checkpoint(global_step + global_step_offset, "milestone")
@@ -550,6 +562,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     strategy: TrainingStrategy,
+    no_val: bool = False,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -604,6 +617,7 @@ def train(
         lr_scheduler=lr_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
+        no_val=no_val,
         loss_step=loss_step_,
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
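
A usage sketch, not part of the commit: the new no_val flag simply gates the validation branch, and each entry script derives it from its own arguments (no_val=args.valid_set_size == 0). A self-contained toy of the gate added to train_loop; the names follow the diff, everything else is illustrative:

    from typing import Iterable, Optional

    def run_epoch(val_dataloader: Optional[Iterable], no_val: bool = False) -> str:
        # mirrors `if val_dataloader is not None and not no_val:` in train_loop
        if val_dataloader is not None and not no_val:
            return "validated"   # milestone checkpoints keyed on validation accuracy
        return "train-only"      # milestone checkpoints fall back to training accuracy

    assert run_epoch([0, 1]) == "validated"
    assert run_epoch([0, 1], no_val=True) == "train-only"  # e.g. valid_set_size == 0
    assert run_epoch(None) == "train-only"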