-rw-r--r--   infer.py                                                      154
-rw-r--r--   pipelines/stable_diffusion/clip_guided_stable_diffusion.py   169
-rw-r--r--   schedulers/scheduling_euler_a.py                                6
3 files changed, 113 insertions, 216 deletions
diff --git a/infer.py b/infer.py
index de3d792..40720ea 100644
--- a/infer.py
+++ b/infer.py
@@ -1,18 +1,21 @@
 import argparse
 import datetime
 import logging
+import sys
+import shlex
+import cmd
 from pathlib import Path
 from torch import autocast
 import torch
 import json
 from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
-from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
 from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 
-def parse_args():
+def create_args_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
     )
@@ -22,6 +25,30 @@ def parse_args():
         default=None,
     )
     parser.add_argument(
+        "--scheduler",
+        type=str,
+        choices=["plms", "ddim", "klms", "euler_a"],
+        default="euler_a",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/inference",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+    )
+
+    return parser
+
+
+def create_cmd_parser():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
         "--prompt",
         type=str,
         default=None,
@@ -49,28 +76,17 @@ def parse_args():
     parser.add_argument(
         "--batch_num",
         type=int,
-        default=50,
+        default=1,
     )
     parser.add_argument(
         "--steps",
         type=int,
-        default=120,
-    )
-    parser.add_argument(
-        "--scheduler",
-        type=str,
-        choices=["plms", "ddim", "klms", "euler_a"],
-        default="euler_a",
+        default=70,
     )
     parser.add_argument(
         "--guidance_scale",
         type=int,
-        default=7.5,
-    )
-    parser.add_argument(
-        "--clip_guidance_scale",
-        type=int,
-        default=100,
+        default=7,
     )
     parser.add_argument(
         "--seed",
@@ -78,21 +94,21 @@ def parse_args():
         default=torch.random.seed(),
     )
     parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output/inference",
-    )
-    parser.add_argument(
         "--config",
         type=str,
         default=None,
     )
 
-    args = parser.parse_args()
+    return parser
+
+
+def run_parser(parser, input=None):
+    args = parser.parse_known_args(input)[0]
+
     if args.config is not None:
         with open(args.config, 'rt') as f:
-            args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)["args"]))
+            args = parser.parse_known_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))[0]
 
     return args
 
@@ -104,24 +120,24 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def gen(args, output_dir):
-    tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
-    text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
-    clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
-    vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
-    unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
-    feature_extractor = CLIPFeatureExtractor.from_pretrained(
-        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
+def create_pipeline(model, scheduler, dtype=torch.bfloat16):
+    print("Loading Stable Diffusion pipeline...")
 
-    if args.scheduler == "plms":
+    tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
+    text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
+    vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
+    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype)
+
+    if scheduler == "plms":
         scheduler = PNDMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
         )
-    elif args.scheduler == "klms":
+    elif scheduler == "klms":
         scheduler = LMSDiscreteScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
-    elif args.scheduler == "ddim":
+    elif scheduler == "ddim":
         scheduler = DDIMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
         )
@@ -135,13 +151,24 @@ def gen(args, output_dir):
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        clip_model=clip_model,
         scheduler=scheduler,
         feature_extractor=feature_extractor
     )
     pipeline.enable_attention_slicing()
     pipeline.to("cuda")
 
+    print("Pipeline loaded.")
+
+    return pipeline
+
+
+def generate(output_dir, pipeline, args):
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    save_args(output_dir, args)
+
     with autocast("cuda"):
         for i in range(args.batch_num):
             generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
@@ -152,7 +179,6 @@ def gen(args, output_dir):
                 negative_prompt=args.negative_prompt,
                 num_inference_steps=args.steps,
                 guidance_scale=args.guidance_scale,
-                clip_guidance_scale=args.clip_guidance_scale,
                 generator=generator,
             ).images
 
@@ -160,18 +186,56 @@ def gen(args, output_dir):
                 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
 
 
-def main():
-    args = parse_args()
+class CmdParse(cmd.Cmd):
+    prompt = 'dream> '
+    commands = []
 
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
-    output_dir.mkdir(parents=True, exist_ok=True)
+    def __init__(self, output_dir, pipeline, parser):
+        super().__init__()
 
-    save_args(output_dir, args)
+        self.output_dir = output_dir
+        self.pipeline = pipeline
+        self.parser = parser
+
+    def default(self, line):
+        line = line.replace("'", "\\'")
+
+        try:
+            elements = shlex.split(line)
+        except ValueError as e:
+            print(str(e))
+
+        if elements[0] == 'q':
+            return True
+
+        try:
+            args = run_parser(self.parser, elements)
+        except SystemExit:
+            self.parser.print_help()
+
+        if len(args.prompt) == 0:
+            print('Try again with a prompt!')
+
+        try:
+            generate(self.output_dir, self.pipeline, args)
+        except KeyboardInterrupt:
+            print('Generation cancelled.')
+
+    def do_exit(self, line):
+        return True
+
+
+def main():
+    logging.basicConfig(stream=sys.stdout, level=logging.WARN)
 
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    args_parser = create_args_parser()
+    args = run_parser(args_parser)
+    output_dir = Path(args.output_dir)
 
-    gen(args, output_dir)
+    pipeline = create_pipeline(args.model, args.scheduler)
+    cmd_parser = create_cmd_parser()
+    cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
+    cmd_prompt.cmdloop()
 
 
 if __name__ == "__main__":
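
Note: the hunks above turn infer.py's one-shot gen() flow into a persistent "dream> " REPL backed by two parsers (startup options vs. per-prompt options). The sketch below is a hypothetical, non-interactive driver built only from functions added in this diff (create_args_parser, create_cmd_parser, run_parser, create_pipeline, generate); the prompt text and argument values are placeholders, not part of the commit.

    # Hypothetical driver script; assumes it runs next to the patched infer.py.
    from pathlib import Path

    from infer import (create_args_parser, create_cmd_parser, create_pipeline,
                       generate, run_parser)

    # Startup options (--model, --scheduler, --output_dir, --config) come from argv.
    args = run_parser(create_args_parser())
    pipeline = create_pipeline(args.model, args.scheduler)

    # The same tokens a user would type after the "dream> " prompt.
    cmd_args = run_parser(create_cmd_parser(),
                          ["--prompt", "a watercolor fox", "--steps", "30", "--batch_num", "2"])
    generate(Path(args.output_dir), pipeline, cmd_args)
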
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
index 306d9a9..ddf7ce1 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
@@ -17,53 +17,14 @@ from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 
-class MakeCutouts(nn.Module):
-    def __init__(self, cut_size, cut_power=1.0):
-        super().__init__()
-
-        self.cut_size = cut_size
-        self.cut_power = cut_power
-
-    def forward(self, pixel_values, num_cutouts):
-        sideY, sideX = pixel_values.shape[2:4]
-        max_size = min(sideX, sideY)
-        min_size = min(sideX, sideY, self.cut_size)
-        cutouts = []
-        for _ in range(num_cutouts):
-            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
-            offsetx = torch.randint(0, sideX - size + 1, ())
-            offsety = torch.randint(0, sideY - size + 1, ())
-            cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size]
-            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
-        return torch.cat(cutouts)
-
-
-def spherical_dist_loss(x, y):
-    x = F.normalize(x, dim=-1)
-    y = F.normalize(y, dim=-1)
-    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
-
-
-def set_requires_grad(model, value):
-    for param in model.parameters():
-        param.requires_grad = value
-
-
 class CLIPGuidedStableDiffusion(DiffusionPipeline):
-    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
-    - https://github.com/Jack000/glid-3-xl
-    - https://github.dev/crowsonkb/k-diffusion
-    """
-
     def __init__(
         self,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
-        clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
-        feature_extractor: CLIPFeatureExtractor,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
         **kwargs,
     ):
         super().__init__()
@@ -85,19 +46,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
-            clip_model=clip_model,
             tokenizer=tokenizer,
             unet=unet,
             scheduler=scheduler,
-            feature_extractor=feature_extractor,
         )
 
-        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
-        self.make_cutouts = MakeCutouts(feature_extractor.size)
-
-        set_requires_grad(self.text_encoder, False)
-        set_requires_grad(self.clip_model, False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -125,87 +78,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
-    def freeze_vae(self):
-        set_requires_grad(self.vae, False)
-
-    def unfreeze_vae(self):
-        set_requires_grad(self.vae, True)
-
-    def freeze_unet(self):
-        set_requires_grad(self.unet, False)
-
-    def unfreeze_unet(self):
-        set_requires_grad(self.unet, True)
-
-    @torch.enable_grad()
-    def cond_fn(
-        self,
-        latents,
-        timestep,
-        index,
-        text_embeddings,
-        noise_pred_original,
-        text_embeddings_clip,
-        clip_guidance_scale,
-        num_cutouts,
-        use_cutouts=True,
-    ):
-        latents = latents.detach().requires_grad_()
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-            latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
-        else:
-            latent_model_input = latents
-
-        # predict the noise residual
-        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
-
-        if isinstance(self.scheduler, PNDMScheduler):
-            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-            beta_prod_t = 1 - alpha_prod_t
-            # compute predicted original sample from predicted noise also called
-            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
-
-            fac = torch.sqrt(beta_prod_t)
-            sample = pred_original_sample * (fac) + latents * (1 - fac)
-        elif isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            sample = latents - sigma * noise_pred
-        else:
-            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
-
-        sample = 1 / 0.18215 * sample
-        image = self.vae.decode(sample).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        if use_cutouts:
-            image = self.make_cutouts(image, num_cutouts)
-        else:
-            image = transforms.Resize(self.feature_extractor.size)(image)
-        image = self.normalize(image)
-
-        image_embeddings_clip = self.clip_model.get_image_features(image).float()
-        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
-        if use_cutouts:
-            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
-            dists = dists.view([num_cutouts, sample.shape[0], -1])
-            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
-        else:
-            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
-
-        grads = -torch.autograd.grad(loss, latents)[0]
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents.detach() + grads * (sigma**2)
-            noise_pred = noise_pred_original
-        else:
-            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
-        return noise_pred, latents
-
     @torch.no_grad()
     def __call__(
         self,
@@ -216,10 +88,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
-        clip_guidance_scale: Optional[float] = 100,
-        clip_prompt: Optional[Union[str, List[str]]] = None,
-        num_cutouts: Optional[int] = 4,
-        use_cutouts: Optional[bool] = True,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -305,24 +173,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
+            print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
-        if clip_guidance_scale > 0:
-            if clip_prompt is not None:
-                clip_text_inputs = self.tokenizer(
-                    clip_prompt,
-                    padding="max_length",
-                    max_length=self.tokenizer.model_max_length,
-                    truncation=True,
-                    return_tensors="pt",
-                )
-                clip_text_input_ids = clip_text_inputs.input_ids
-            else:
-                clip_text_input_ids = text_input_ids
-            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device))
-            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -357,7 +211,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(self.device)
+            latents = latents.to(self.device)
 
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -414,23 +268,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-            # perform clip guidance
-            if clip_guidance_scale > 0:
-                text_embeddings_for_guidance = (
-                    text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
-                )
-                noise_pred, latents = self.cond_fn(
-                    latents,
-                    t,
-                    i,
-                    text_embeddings_for_guidance,
-                    noise_pred,
-                    text_embeddings_clip,
-                    clip_guidance_scale,
-                    num_cutouts,
-                    use_cutouts,
-                )
-
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
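
For reference, a minimal usage sketch (an assumption, not part of this commit) of calling the slimmed-down pipeline after the CLIP-guidance arguments were removed; only parameters that remain in CLIPGuidedStableDiffusion.__call__ are passed, and the prompt and seed are placeholders.

    from torch import autocast
    import torch

    # `pipeline` as returned by infer.create_pipeline(); with the change above,
    # only classifier-free guidance remains, so no clip_* kwargs are accepted.
    generator = torch.Generator(device="cuda").manual_seed(42)  # placeholder seed
    with autocast("cuda"):
        images = pipeline(
            prompt="a placeholder prompt",
            negative_prompt=None,
            num_inference_steps=70,
            guidance_scale=7,
            generator=generator,
        ).images
    images[0].save("sample.jpg")
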
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 57a56de..29ebd07 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -216,7 +216,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
 
         self.num_inference_steps = num_inference_steps
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
+        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device)
         self.timesteps = self.sigmas
 
     def add_noise_to_input(
@@ -272,11 +272,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         """
         latents = sample
         sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
-
-        # if callback is not None:
-        #     callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output})
         d = to_d(latents, timestep, model_output)
-        # Euler method
         dt = sigma_down - timestep
         latents = latents + d * dt
         latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,