author     Volpeon <git@volpeon.ink>    2022-10-02 12:56:58 +0200
committer  Volpeon <git@volpeon.ink>    2022-10-02 12:56:58 +0200
commit     49de8142f523aef3f6adfd0c33a9a160aa7400c0
tree       3638e8ca449bc18acf947ebc0cbc2ee4ecf18a61
parent     Fix seed, better progress bar, fix euler_a for batch size > 1
WIP: img2img
-rw-r--r--  infer.py                                                                                                                        | 90
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py (renamed from pipelines/stable_diffusion/clip_guided_stable_diffusion.py)  | 80
2 files changed, 131 insertions(+), 39 deletions(-)
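The change wires an optional init image into the inference script: infer.py gains --image and --image_strength flags, and the pipeline (renamed to VlpnStableDiffusion) accepts a PIL image through its latents argument and starts denoising from a noised VAE encoding of it instead of from pure noise. As an illustration only (prompt and file name are made up, not part of the commit), a session at the interactive prompt could look like:

dream> a watercolor landscape --image init.jpg --image_strength 0.6 --steps 50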
diff --git a/infer.py b/infer.py
@@ -8,13 +8,47 @@ from pathlib import Path
 from torch import autocast
 import torch
 import json
+from PIL import Image
 from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
-from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 
+default_args = {
+    "model": None,
+    "scheduler": "euler_a",
+    "output_dir": "output/inference",
+    "config": None,
+}
+
+
+default_cmds = {
+    "prompt": None,
+    "negative_prompt": None,
+    "image": None,
+    "image_strength": .7,
+    "width": 512,
+    "height": 512,
+    "batch_size": 1,
+    "batch_num": 1,
+    "steps": 50,
+    "guidance_scale": 7.0,
+    "seed": None,
+    "config": None,
+}
+
+
+def merge_dicts(d1, *args):
+    d1 = d1.copy()
+
+    for d in args:
+        d1.update({k: v for (k, v) in d.items() if v is not None})
+
+    return d1
+
+
 def create_args_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
@@ -22,23 +56,19 @@ def create_args_parser():
     parser.add_argument(
         "--model",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--scheduler",
         type=str,
         choices=["plms", "ddim", "klms", "euler_a"],
-        default="euler_a",
     )
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="output/inference",
     )
     parser.add_argument(
         "--config",
        type=str,
-        default=None,
     )
 
     return parser
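Removing the argparse defaults is what makes the new default_args/default_cmds dictionaries work: an option that is not given on the command line now parses to None, so a value from a --config file (or the hard-coded default) can fill it in instead of being clobbered by argparse's own default. A minimal sketch of the behaviour, not part of the patch, with the scheduler flag as the example:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--scheduler", type=str)   # no default, so an omitted flag yields None
args = parser.parse_args([])                   # simulate an empty command line
print(args.scheduler)                          # None -> config value or default_args["scheduler"] wins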
@@ -51,66 +81,69 @@ def create_cmd_parser():
     parser.add_argument(
         "--prompt",
         type=str,
-        default=None,
     )
     parser.add_argument(
         "--negative_prompt",
         type=str,
-        default=None,
+    )
+    parser.add_argument(
+        "--image",
+        type=str,
+    )
+    parser.add_argument(
+        "--image_strength",
+        type=float,
     )
     parser.add_argument(
         "--width",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--height",
         type=int,
-        default=512,
     )
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--batch_num",
         type=int,
-        default=1,
     )
     parser.add_argument(
         "--steps",
         type=int,
-        default=70,
     )
     parser.add_argument(
         "--guidance_scale",
-        type=int,
-        default=7,
+        type=float,
     )
     parser.add_argument(
         "--seed",
         type=int,
-        default=None,
     )
     parser.add_argument(
         "--config",
         type=str,
-        default=None,
     )
 
     return parser
 
 
-def run_parser(parser, input=None):
+def run_parser(parser, defaults, input=None):
     args = parser.parse_known_args(input)[0]
+    conf_args = argparse.Namespace()
 
     if args.config is not None:
         with open(args.config, 'rt') as f:
-            args = parser.parse_known_args(
+            conf_args = parser.parse_known_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))[0]
 
-    return args
+    res = defaults.copy()
+    for dict in [vars(conf_args), vars(args)]:
+        res.update({k: v for (k, v) in dict.items() if v is not None})
+
+    return argparse.Namespace(**res)
 
 
 def save_args(basepath, args, extra={}):
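run_parser now resolves each option in three layers: the defaults dict passed in, then values from the JSON --config file, then values given directly on the command line, with later layers only overriding earlier ones when they are not None. A small self-contained sketch of that merge, not part of the patch (the concrete numbers are hypothetical):

defaults = {"steps": 50, "guidance_scale": 7.0, "image_strength": 0.7}
conf_args = {"steps": 30, "guidance_scale": None, "image_strength": None}   # from --config
cli_args = {"steps": None, "guidance_scale": 9.0, "image_strength": None}   # from the command line

res = defaults.copy()
for d in (conf_args, cli_args):
    res.update({k: v for k, v in d.items() if v is not None})

print(res)   # {'steps': 30, 'guidance_scale': 9.0, 'image_strength': 0.7}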
@@ -146,7 +179,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
     )
 
-    pipeline = CLIPGuidedStableDiffusion(
+    pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
@@ -154,7 +187,7 @@ def create_pipeline(model, scheduler, dtype=torch.bfloat16):
         scheduler=scheduler,
         feature_extractor=feature_extractor
     )
-    pipeline.enable_attention_slicing()
+    # pipeline.enable_attention_slicing()
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
@@ -171,6 +204,13 @@ def generate(output_dir, pipeline, args):
 
     save_args(output_dir, args)
 
+    if args.image:
+        init_image = Image.open(args.image)
+        if not init_image.mode == "RGB":
+            init_image = init_image.convert("RGB")
+    else:
+        init_image = None
+
     with autocast("cuda"):
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
@@ -184,11 +224,15 @@ def generate(output_dir, pipeline, args):
                 num_inference_steps=args.steps,
                 guidance_scale=args.guidance_scale,
                 generator=generator,
+                latents=init_image,
             ).images
 
             for j, image in enumerate(images):
                 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
 
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
 
 class CmdParse(cmd.Cmd):
     prompt = 'dream> '
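For reference, a direct call into the pipeline with an init image could look like the sketch below; it is not part of the patch, the model path, prompt, and file names are placeholders, and pipeline stands for the VlpnStableDiffusion instance returned by create_pipeline above. Note that in this WIP commit infer.py passes the PIL image via the latents argument and does not yet forward --image_strength to the pipeline's new strength parameter.

from PIL import Image
import torch

init_image = Image.open("init.jpg").convert("RGB")            # hypothetical input image
generator = torch.Generator(device="cuda").manual_seed(42)

images = pipeline(
    prompt="a watercolor landscape",
    strength=0.6,                        # only used when an init image is given
    num_inference_steps=50,
    guidance_scale=7.0,
    generator=generator,
    latents=init_image,                  # a PIL.Image triggers the img2img path
).images
images[0].save("out.jpg")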
@@ -213,7 +257,7 @@ class CmdParse(cmd.Cmd):
             return True
 
         try:
-            args = run_parser(self.parser, elements)
+            args = run_parser(self.parser, default_cmds, elements)
         except SystemExit:
             self.parser.print_help()
 
@@ -233,7 +277,7 @@ def main():
     logging.basicConfig(stream=sys.stdout, level=logging.WARN)
 
     args_parser = create_args_parser()
-    args = run_parser(args_parser)
+    args = run_parser(args_parser, default_args)
     output_dir = Path(args.output_dir)
 
     pipeline = create_pipeline(args.model, args.scheduler)
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index eff74b5..4c793a8 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -2,22 +2,29 @@ import inspect
 import warnings
 from typing import List, Optional, Union
 
+import numpy as np
 import torch
-from torch import nn
-from torch.nn import functional as F
+import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from torchvision import transforms
-from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class CLIPGuidedStableDiffusion(DiffusionPipeline):
+def preprocess(image, w, h):
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
         vae: AutoencoderKL,
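The new preprocess helper turns a PIL image into the NCHW float tensor the VAE expects: resize to the target resolution, scale to [0, 1], reorder to (batch, channel, height, width), and map to [-1, 1]. A quick shape check, not part of the patch, assuming the preprocess function from the hunk above is in scope (the solid-colour image is just a stand-in for a real photo):

import numpy as np
import PIL.Image
import torch

img = PIL.Image.new("RGB", (640, 480), color=(128, 64, 200))  # stand-in init image
x = preprocess(img, 512, 512)
print(x.shape, x.dtype)                  # torch.Size([1, 3, 512, 512]) torch.float32
print(x.min().item(), x.max().item())    # both within [-1.0, 1.0]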
@@ -83,13 +90,14 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str]],
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
     ):
@@ -99,6 +107,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
+            strength (`float`, *optional*, defaults to 0.8):
+                Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+                `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                noise will be maximum and the denoising process will run for the full number of iterations specified in
+                `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -158,6 +172,42 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
+        if strength < 0 or strength > 1:
+            raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}")
+
+        # set timesteps
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        offset = self.scheduler.config.get("steps_offset", 0)
+
+        if latents is not None and isinstance(latents, PIL.Image.Image):
+            latents = preprocess(latents, width, height)
+            latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
+            latents = latent_dist.sample(generator=generator)
+            latents = 0.18215 * latents
+            latents = torch.cat([latents] * batch_size)
+
+            # get the original timestep using init_timestep
+            init_timestep = int(num_inference_steps * strength) + offset
+            init_timestep = min(init_timestep, num_inference_steps)
+
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                timesteps = torch.tensor(
+                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+                )
+            elif isinstance(self.scheduler, EulerAScheduler):
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+            else:
+                timesteps = self.scheduler.timesteps[-init_timestep]
+                timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+            # add noise to latents using the timesteps
+            noise = torch.randn(latents.shape, generator=generator, device=self.device)
+            latents = self.scheduler.add_noise(latents, noise, timesteps)
+        else:
+            init_timestep = num_inference_steps + offset
+
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
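With concrete numbers (hypothetical, not part of the patch), the strength bookkeeping above works out as follows: the init latents are noised to a timestep that depends on strength, and the denoising loop later starts from that point in the schedule rather than from pure noise.

num_inference_steps = 50
strength = 0.8
offset = 1        # steps_offset from the scheduler config, when present

init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)   # 41
t_start = max(num_inference_steps - init_timestep + offset, 0)                            # 10 (computed in a later hunk)
print(init_timestep, t_start)    # 41 10 -> the loop will cover scheduler.timesteps[10:], roughly 40 of the 50 steps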
@@ -213,15 +263,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
             latents = latents.to(self.device)
 
-        # set timesteps
-        self.scheduler.set_timesteps(num_inference_steps)
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
 
         # Some schedulers like PNDM have timesteps as arrays
         # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -244,10 +290,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            t_index = t_start + i
+
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
+                sigma = self.scheduler.sigmas[t_index]
                 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
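Because the loop now iterates only over the tail scheduler.timesteps[t_start:], the loop counter i no longer matches positions in the full-length sigmas/timesteps arrays; t_index = t_start + i restores that alignment. A toy illustration, not part of the patch, reusing the hypothetical t_start = 10 from above:

t_start = 10
tail = list(range(50))[t_start:]        # stand-in for scheduler.timesteps[t_start:]
for i, t in enumerate(tail):
    t_index = t_start + i               # 10, 11, ..., 49 -> valid index into the full sigmas array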
@@ -270,10 +318,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
             elif isinstance(self.scheduler, EulerAScheduler):
-                if i < self.scheduler.timesteps.shape[0] - 1:  # avoid out of bound error
-                    t_prev = self.scheduler.timesteps[i+1]
+                if t_index < self.scheduler.timesteps.shape[0] - 1:  # avoid out of bound error
+                    t_prev = self.scheduler.timesteps[t_index+1]
                     latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
                 else:
                     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample