diff options
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | dreambooth.py | 12 | ||||
-rw-r--r-- | infer.py | 111 | ||||
-rw-r--r-- | pipelines/stable_diffusion/clip_guided_stable_diffusion.py | 457 | ||||
-rw-r--r-- | schedulers/scheduling_euler_a.py | 323 |
5 files changed, 869 insertions, 36 deletions
@@ -160,5 +160,5 @@ cython_debug/ | |||
160 | #.idea/ | 160 | #.idea/ |
161 | 161 | ||
162 | output/ | 162 | output/ |
163 | conf*.json | 163 | conf/ |
164 | v1-inference.yaml* | 164 | v1-inference.yaml* |
diff --git a/dreambooth.py b/dreambooth.py index 39c4851..4d7366c 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -59,7 +59,7 @@ def parse_args(): | |||
59 | parser.add_argument( | 59 | parser.add_argument( |
60 | "--repeats", | 60 | "--repeats", |
61 | type=int, | 61 | type=int, |
62 | default=100, | 62 | default=1, |
63 | help="How many times to repeat the training data." | 63 | help="How many times to repeat the training data." |
64 | ) | 64 | ) |
65 | parser.add_argument( | 65 | parser.add_argument( |
@@ -375,7 +375,6 @@ class Checkpointer: | |||
375 | @torch.no_grad() | 375 | @torch.no_grad() |
376 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): | 376 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): |
377 | samples_path = Path(self.output_dir).joinpath("samples") | 377 | samples_path = Path(self.output_dir).joinpath("samples") |
378 | samples_path.mkdir(parents=True, exist_ok=True) | ||
379 | 378 | ||
380 | unwrapped = self.accelerator.unwrap_model(self.unet) | 379 | unwrapped = self.accelerator.unwrap_model(self.unet) |
381 | pipeline = StableDiffusionPipeline( | 380 | pipeline = StableDiffusionPipeline( |
@@ -403,6 +402,7 @@ class Checkpointer: | |||
403 | 402 | ||
404 | all_samples = [] | 403 | all_samples = [] |
405 | file_path = samples_path.joinpath("stable", f"step_{step}.png") | 404 | file_path = samples_path.joinpath("stable", f"step_{step}.png") |
405 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
406 | 406 | ||
407 | data_enum = enumerate(val_data) | 407 | data_enum = enumerate(val_data) |
408 | 408 | ||
@@ -436,6 +436,7 @@ class Checkpointer: | |||
436 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: |
437 | all_samples = [] | 437 | all_samples = [] |
438 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 438 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
439 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
439 | 440 | ||
440 | data_enum = enumerate(data) | 441 | data_enum = enumerate(data) |
441 | 442 | ||
@@ -496,11 +497,15 @@ def main(): | |||
496 | cur_class_images = len(list(class_images_dir.iterdir())) | 497 | cur_class_images = len(list(class_images_dir.iterdir())) |
497 | 498 | ||
498 | if cur_class_images < args.num_class_images: | 499 | if cur_class_images < args.num_class_images: |
499 | torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32 | 500 | torch_dtype = torch.float32 |
501 | if accelerator.device.type == "cuda": | ||
502 | torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision] | ||
503 | |||
500 | pipeline = StableDiffusionPipeline.from_pretrained( | 504 | pipeline = StableDiffusionPipeline.from_pretrained( |
501 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) | 505 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) |
502 | pipeline.enable_attention_slicing() | 506 | pipeline.enable_attention_slicing() |
503 | pipeline.set_progress_bar_config(disable=True) | 507 | pipeline.set_progress_bar_config(disable=True) |
508 | pipeline.to(accelerator.device) | ||
504 | 509 | ||
505 | num_new_images = args.num_class_images - cur_class_images | 510 | num_new_images = args.num_class_images - cur_class_images |
506 | logger.info(f"Number of class images to sample: {num_new_images}.") | 511 | logger.info(f"Number of class images to sample: {num_new_images}.") |
@@ -509,7 +514,6 @@ def main(): | |||
509 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) | 514 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) |
510 | 515 | ||
511 | sample_dataloader = accelerator.prepare(sample_dataloader) | 516 | sample_dataloader = accelerator.prepare(sample_dataloader) |
512 | pipeline.to(accelerator.device) | ||
513 | 517 | ||
514 | for example in tqdm( | 518 | for example in tqdm( |
515 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process | 519 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process |
@@ -1,18 +1,15 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import datetime | 2 | import datetime |
3 | import logging | ||
3 | from pathlib import Path | 4 | from pathlib import Path |
4 | from torch import autocast | 5 | from torch import autocast |
5 | from diffusers import StableDiffusionPipeline | ||
6 | import torch | 6 | import torch |
7 | import json | 7 | import json |
8 | from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler | 8 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler |
9 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor | 9 | from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor |
10 | from slugify import slugify | 10 | from slugify import slugify |
11 | from pipelines.stable_diffusion.no_check import NoCheck | 11 | from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion |
12 | 12 | from schedulers.scheduling_euler_a import EulerAScheduler | |
13 | model_id = "path-to-your-trained-model" | ||
14 | |||
15 | prompt = "A photo of sks dog in a bucket" | ||
16 | 13 | ||
17 | 14 | ||
18 | def parse_args(): | 15 | def parse_args(): |
@@ -30,6 +27,21 @@ def parse_args(): | |||
30 | default=None, | 27 | default=None, |
31 | ) | 28 | ) |
32 | parser.add_argument( | 29 | parser.add_argument( |
30 | "--negative_prompt", | ||
31 | type=str, | ||
32 | default=None, | ||
33 | ) | ||
34 | parser.add_argument( | ||
35 | "--width", | ||
36 | type=int, | ||
37 | default=512, | ||
38 | ) | ||
39 | parser.add_argument( | ||
40 | "--height", | ||
41 | type=int, | ||
42 | default=512, | ||
43 | ) | ||
44 | parser.add_argument( | ||
33 | "--batch_size", | 45 | "--batch_size", |
34 | type=int, | 46 | type=int, |
35 | default=1, | 47 | default=1, |
@@ -42,17 +54,28 @@ def parse_args(): | |||
42 | parser.add_argument( | 54 | parser.add_argument( |
43 | "--steps", | 55 | "--steps", |
44 | type=int, | 56 | type=int, |
45 | default=80, | 57 | default=120, |
46 | ) | 58 | ) |
47 | parser.add_argument( | 59 | parser.add_argument( |
48 | "--scale", | 60 | "--scheduler", |
61 | type=str, | ||
62 | choices=["plms", "ddim", "klms", "euler_a"], | ||
63 | default="euler_a", | ||
64 | ) | ||
65 | parser.add_argument( | ||
66 | "--guidance_scale", | ||
49 | type=int, | 67 | type=int, |
50 | default=7.5, | 68 | default=7.5, |
51 | ) | 69 | ) |
52 | parser.add_argument( | 70 | parser.add_argument( |
71 | "--clip_guidance_scale", | ||
72 | type=int, | ||
73 | default=100, | ||
74 | ) | ||
75 | parser.add_argument( | ||
53 | "--seed", | 76 | "--seed", |
54 | type=int, | 77 | type=int, |
55 | default=None, | 78 | default=torch.random.seed(), |
56 | ) | 79 | ) |
57 | parser.add_argument( | 80 | parser.add_argument( |
58 | "--output_dir", | 81 | "--output_dir", |
@@ -81,31 +104,39 @@ def save_args(basepath, args, extra={}): | |||
81 | json.dump(info, f, indent=4) | 104 | json.dump(info, f, indent=4) |
82 | 105 | ||
83 | 106 | ||
84 | def main(): | 107 | def gen(args, output_dir): |
85 | args = parse_args() | ||
86 | |||
87 | seed = args.seed or torch.random.seed() | ||
88 | |||
89 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
90 | output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") | ||
91 | output_dir.mkdir(parents=True, exist_ok=True) | ||
92 | save_args(output_dir, args) | ||
93 | |||
94 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) | 108 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) |
95 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) | 109 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) |
110 | clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) | ||
96 | vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) | 111 | vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) |
97 | unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) | 112 | unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) |
98 | feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16) | 113 | feature_extractor = CLIPFeatureExtractor.from_pretrained( |
114 | "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16) | ||
99 | 115 | ||
100 | pipeline = StableDiffusionPipeline( | 116 | if args.scheduler == "plms": |
117 | scheduler = PNDMScheduler( | ||
118 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | ||
119 | ) | ||
120 | elif args.scheduler == "klms": | ||
121 | scheduler = LMSDiscreteScheduler( | ||
122 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
123 | ) | ||
124 | elif args.scheduler == "ddim": | ||
125 | scheduler = DDIMScheduler( | ||
126 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | ||
127 | ) | ||
128 | else: | ||
129 | scheduler = EulerAScheduler( | ||
130 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False | ||
131 | ) | ||
132 | |||
133 | pipeline = CLIPGuidedStableDiffusion( | ||
101 | text_encoder=text_encoder, | 134 | text_encoder=text_encoder, |
102 | vae=vae, | 135 | vae=vae, |
103 | unet=unet, | 136 | unet=unet, |
104 | tokenizer=tokenizer, | 137 | tokenizer=tokenizer, |
105 | scheduler=PNDMScheduler( | 138 | clip_model=clip_model, |
106 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 139 | scheduler=scheduler, |
107 | ), | ||
108 | safety_checker=NoCheck(), | ||
109 | feature_extractor=feature_extractor | 140 | feature_extractor=feature_extractor |
110 | ) | 141 | ) |
111 | pipeline.enable_attention_slicing() | 142 | pipeline.enable_attention_slicing() |
@@ -113,16 +144,34 @@ def main(): | |||
113 | 144 | ||
114 | with autocast("cuda"): | 145 | with autocast("cuda"): |
115 | for i in range(args.batch_num): | 146 | for i in range(args.batch_num): |
116 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | 147 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) |
117 | images = pipeline( | 148 | images = pipeline( |
118 | [args.prompt] * args.batch_size, | 149 | prompt=[args.prompt] * args.batch_size, |
150 | height=args.height, | ||
151 | width=args.width, | ||
152 | negative_prompt=args.negative_prompt, | ||
119 | num_inference_steps=args.steps, | 153 | num_inference_steps=args.steps, |
120 | guidance_scale=args.scale, | 154 | guidance_scale=args.guidance_scale, |
155 | clip_guidance_scale=args.clip_guidance_scale, | ||
121 | generator=generator, | 156 | generator=generator, |
122 | ).images | 157 | ).images |
123 | 158 | ||
124 | for j, image in enumerate(images): | 159 | for j, image in enumerate(images): |
125 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) | 160 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) |
161 | |||
162 | |||
163 | def main(): | ||
164 | args = parse_args() | ||
165 | |||
166 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
167 | output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") | ||
168 | output_dir.mkdir(parents=True, exist_ok=True) | ||
169 | |||
170 | save_args(output_dir, args) | ||
171 | |||
172 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | ||
173 | |||
174 | gen(args, output_dir) | ||
126 | 175 | ||
127 | 176 | ||
128 | if __name__ == "__main__": | 177 | if __name__ == "__main__": |
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py new file mode 100644 index 0000000..306d9a9 --- /dev/null +++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py | |||
@@ -0,0 +1,457 @@ | |||
1 | import inspect | ||
2 | import warnings | ||
3 | from typing import List, Optional, Union | ||
4 | |||
5 | import torch | ||
6 | from torch import nn | ||
7 | from torch.nn import functional as F | ||
8 | |||
9 | from diffusers.configuration_utils import FrozenDict | ||
10 | from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel | ||
11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | ||
12 | from diffusers.utils import logging | ||
13 | from torchvision import transforms | ||
14 | from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer | ||
15 | from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward | ||
16 | |||
17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | ||
18 | |||
19 | |||
20 | class MakeCutouts(nn.Module): | ||
21 | def __init__(self, cut_size, cut_power=1.0): | ||
22 | super().__init__() | ||
23 | |||
24 | self.cut_size = cut_size | ||
25 | self.cut_power = cut_power | ||
26 | |||
27 | def forward(self, pixel_values, num_cutouts): | ||
28 | sideY, sideX = pixel_values.shape[2:4] | ||
29 | max_size = min(sideX, sideY) | ||
30 | min_size = min(sideX, sideY, self.cut_size) | ||
31 | cutouts = [] | ||
32 | for _ in range(num_cutouts): | ||
33 | size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) | ||
34 | offsetx = torch.randint(0, sideX - size + 1, ()) | ||
35 | offsety = torch.randint(0, sideY - size + 1, ()) | ||
36 | cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size] | ||
37 | cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) | ||
38 | return torch.cat(cutouts) | ||
39 | |||
40 | |||
41 | def spherical_dist_loss(x, y): | ||
42 | x = F.normalize(x, dim=-1) | ||
43 | y = F.normalize(y, dim=-1) | ||
44 | return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) | ||
45 | |||
46 | |||
47 | def set_requires_grad(model, value): | ||
48 | for param in model.parameters(): | ||
49 | param.requires_grad = value | ||
50 | |||
51 | |||
52 | class CLIPGuidedStableDiffusion(DiffusionPipeline): | ||
53 | """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000 | ||
54 | - https://github.com/Jack000/glid-3-xl | ||
55 | - https://github.dev/crowsonkb/k-diffusion | ||
56 | """ | ||
57 | |||
58 | def __init__( | ||
59 | self, | ||
60 | vae: AutoencoderKL, | ||
61 | text_encoder: CLIPTextModel, | ||
62 | clip_model: CLIPModel, | ||
63 | tokenizer: CLIPTokenizer, | ||
64 | unet: UNet2DConditionModel, | ||
65 | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], | ||
66 | feature_extractor: CLIPFeatureExtractor, | ||
67 | **kwargs, | ||
68 | ): | ||
69 | super().__init__() | ||
70 | |||
71 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | ||
72 | warnings.warn( | ||
73 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" | ||
74 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " | ||
75 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" | ||
76 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," | ||
77 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" | ||
78 | " file", | ||
79 | DeprecationWarning, | ||
80 | ) | ||
81 | new_config = dict(scheduler.config) | ||
82 | new_config["steps_offset"] = 1 | ||
83 | scheduler._internal_dict = FrozenDict(new_config) | ||
84 | |||
85 | self.register_modules( | ||
86 | vae=vae, | ||
87 | text_encoder=text_encoder, | ||
88 | clip_model=clip_model, | ||
89 | tokenizer=tokenizer, | ||
90 | unet=unet, | ||
91 | scheduler=scheduler, | ||
92 | feature_extractor=feature_extractor, | ||
93 | ) | ||
94 | |||
95 | self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) | ||
96 | self.make_cutouts = MakeCutouts(feature_extractor.size) | ||
97 | |||
98 | set_requires_grad(self.text_encoder, False) | ||
99 | set_requires_grad(self.clip_model, False) | ||
100 | |||
101 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): | ||
102 | r""" | ||
103 | Enable sliced attention computation. | ||
104 | |||
105 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention | ||
106 | in several steps. This is useful to save some memory in exchange for a small speed decrease. | ||
107 | |||
108 | Args: | ||
109 | slice_size (`str` or `int`, *optional*, defaults to `"auto"`): | ||
110 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If | ||
111 | a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, | ||
112 | `attention_head_dim` must be a multiple of `slice_size`. | ||
113 | """ | ||
114 | if slice_size == "auto": | ||
115 | # half the attention head size is usually a good trade-off between | ||
116 | # speed and memory | ||
117 | slice_size = self.unet.config.attention_head_dim // 2 | ||
118 | self.unet.set_attention_slice(slice_size) | ||
119 | |||
120 | def disable_attention_slicing(self): | ||
121 | r""" | ||
122 | Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go | ||
123 | back to computing attention in one step. | ||
124 | """ | ||
125 | # set slice_size = `None` to disable `attention slicing` | ||
126 | self.enable_attention_slicing(None) | ||
127 | |||
128 | def freeze_vae(self): | ||
129 | set_requires_grad(self.vae, False) | ||
130 | |||
131 | def unfreeze_vae(self): | ||
132 | set_requires_grad(self.vae, True) | ||
133 | |||
134 | def freeze_unet(self): | ||
135 | set_requires_grad(self.unet, False) | ||
136 | |||
137 | def unfreeze_unet(self): | ||
138 | set_requires_grad(self.unet, True) | ||
139 | |||
140 | @torch.enable_grad() | ||
141 | def cond_fn( | ||
142 | self, | ||
143 | latents, | ||
144 | timestep, | ||
145 | index, | ||
146 | text_embeddings, | ||
147 | noise_pred_original, | ||
148 | text_embeddings_clip, | ||
149 | clip_guidance_scale, | ||
150 | num_cutouts, | ||
151 | use_cutouts=True, | ||
152 | ): | ||
153 | latents = latents.detach().requires_grad_() | ||
154 | |||
155 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
156 | sigma = self.scheduler.sigmas[index] | ||
157 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS | ||
158 | latent_model_input = latents / ((sigma**2 + 1) ** 0.5) | ||
159 | else: | ||
160 | latent_model_input = latents | ||
161 | |||
162 | # predict the noise residual | ||
163 | noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample | ||
164 | |||
165 | if isinstance(self.scheduler, PNDMScheduler): | ||
166 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] | ||
167 | beta_prod_t = 1 - alpha_prod_t | ||
168 | # compute predicted original sample from predicted noise also called | ||
169 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf | ||
170 | pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) | ||
171 | |||
172 | fac = torch.sqrt(beta_prod_t) | ||
173 | sample = pred_original_sample * (fac) + latents * (1 - fac) | ||
174 | elif isinstance(self.scheduler, LMSDiscreteScheduler): | ||
175 | sigma = self.scheduler.sigmas[index] | ||
176 | sample = latents - sigma * noise_pred | ||
177 | else: | ||
178 | raise ValueError(f"scheduler type {type(self.scheduler)} not supported") | ||
179 | |||
180 | sample = 1 / 0.18215 * sample | ||
181 | image = self.vae.decode(sample).sample | ||
182 | image = (image / 2 + 0.5).clamp(0, 1) | ||
183 | |||
184 | if use_cutouts: | ||
185 | image = self.make_cutouts(image, num_cutouts) | ||
186 | else: | ||
187 | image = transforms.Resize(self.feature_extractor.size)(image) | ||
188 | image = self.normalize(image) | ||
189 | |||
190 | image_embeddings_clip = self.clip_model.get_image_features(image).float() | ||
191 | image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) | ||
192 | |||
193 | if use_cutouts: | ||
194 | dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip) | ||
195 | dists = dists.view([num_cutouts, sample.shape[0], -1]) | ||
196 | loss = dists.sum(2).mean(0).sum() * clip_guidance_scale | ||
197 | else: | ||
198 | loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale | ||
199 | |||
200 | grads = -torch.autograd.grad(loss, latents)[0] | ||
201 | |||
202 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
203 | latents = latents.detach() + grads * (sigma**2) | ||
204 | noise_pred = noise_pred_original | ||
205 | else: | ||
206 | noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads | ||
207 | return noise_pred, latents | ||
208 | |||
209 | @torch.no_grad() | ||
210 | def __call__( | ||
211 | self, | ||
212 | prompt: Union[str, List[str]], | ||
213 | negative_prompt: Optional[Union[str, List[str]]] = None, | ||
214 | height: Optional[int] = 512, | ||
215 | width: Optional[int] = 512, | ||
216 | num_inference_steps: Optional[int] = 50, | ||
217 | guidance_scale: Optional[float] = 7.5, | ||
218 | eta: Optional[float] = 0.0, | ||
219 | clip_guidance_scale: Optional[float] = 100, | ||
220 | clip_prompt: Optional[Union[str, List[str]]] = None, | ||
221 | num_cutouts: Optional[int] = 4, | ||
222 | use_cutouts: Optional[bool] = True, | ||
223 | generator: Optional[torch.Generator] = None, | ||
224 | latents: Optional[torch.FloatTensor] = None, | ||
225 | output_type: Optional[str] = "pil", | ||
226 | return_dict: bool = True, | ||
227 | ): | ||
228 | r""" | ||
229 | Function invoked when calling the pipeline for generation. | ||
230 | |||
231 | Args: | ||
232 | prompt (`str` or `List[str]`): | ||
233 | The prompt or prompts to guide the image generation. | ||
234 | height (`int`, *optional*, defaults to 512): | ||
235 | The height in pixels of the generated image. | ||
236 | width (`int`, *optional*, defaults to 512): | ||
237 | The width in pixels of the generated image. | ||
238 | num_inference_steps (`int`, *optional*, defaults to 50): | ||
239 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the | ||
240 | expense of slower inference. | ||
241 | guidance_scale (`float`, *optional*, defaults to 7.5): | ||
242 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | ||
243 | `guidance_scale` is defined as `w` of equation 2. of [Imagen | ||
244 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | ||
245 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | ||
246 | usually at the expense of lower image quality. | ||
247 | eta (`float`, *optional*, defaults to 0.0): | ||
248 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to | ||
249 | [`schedulers.DDIMScheduler`], will be ignored for others. | ||
250 | generator (`torch.Generator`, *optional*): | ||
251 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation | ||
252 | deterministic. | ||
253 | latents (`torch.FloatTensor`, *optional*): | ||
254 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image | ||
255 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | ||
256 | tensor will ge generated by sampling using the supplied random `generator`. | ||
257 | output_type (`str`, *optional*, defaults to `"pil"`): | ||
258 | The output format of the generate image. Choose between | ||
259 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | ||
260 | return_dict (`bool`, *optional*, defaults to `True`): | ||
261 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | ||
262 | plain tuple. | ||
263 | |||
264 | Returns: | ||
265 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | ||
266 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. | ||
267 | When returning a tuple, the first element is a list with the generated images, and the second element is a | ||
268 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" | ||
269 | (nsfw) content, according to the `safety_checker`. | ||
270 | """ | ||
271 | |||
272 | if isinstance(prompt, str): | ||
273 | batch_size = 1 | ||
274 | elif isinstance(prompt, list): | ||
275 | batch_size = len(prompt) | ||
276 | else: | ||
277 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | ||
278 | |||
279 | if negative_prompt is None: | ||
280 | negative_prompt = [""] * batch_size | ||
281 | elif isinstance(negative_prompt, str): | ||
282 | negative_prompt = [negative_prompt] * batch_size | ||
283 | elif isinstance(negative_prompt, list): | ||
284 | if len(negative_prompt) != batch_size: | ||
285 | raise ValueError( | ||
286 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | ||
287 | else: | ||
288 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") | ||
289 | |||
290 | if height % 8 != 0 or width % 8 != 0: | ||
291 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | ||
292 | |||
293 | # get prompt text embeddings | ||
294 | text_inputs = self.tokenizer( | ||
295 | prompt, | ||
296 | padding="max_length", | ||
297 | max_length=self.tokenizer.model_max_length, | ||
298 | return_tensors="pt", | ||
299 | ) | ||
300 | text_input_ids = text_inputs.input_ids | ||
301 | |||
302 | if text_input_ids.shape[-1] > self.tokenizer.model_max_length: | ||
303 | removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:]) | ||
304 | logger.warning( | ||
305 | "The following part of your input was truncated because CLIP can only handle sequences up to" | ||
306 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" | ||
307 | ) | ||
308 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] | ||
309 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] | ||
310 | |||
311 | if clip_guidance_scale > 0: | ||
312 | if clip_prompt is not None: | ||
313 | clip_text_inputs = self.tokenizer( | ||
314 | clip_prompt, | ||
315 | padding="max_length", | ||
316 | max_length=self.tokenizer.model_max_length, | ||
317 | truncation=True, | ||
318 | return_tensors="pt", | ||
319 | ) | ||
320 | clip_text_input_ids = clip_text_inputs.input_ids | ||
321 | else: | ||
322 | clip_text_input_ids = text_input_ids | ||
323 | text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device)) | ||
324 | text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) | ||
325 | |||
326 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | ||
327 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | ||
328 | # corresponds to doing no classifier free guidance. | ||
329 | do_classifier_free_guidance = guidance_scale > 1.0 | ||
330 | # get unconditional embeddings for classifier free guidance | ||
331 | if do_classifier_free_guidance: | ||
332 | max_length = text_input_ids.shape[-1] | ||
333 | uncond_input = self.tokenizer( | ||
334 | negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" | ||
335 | ) | ||
336 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] | ||
337 | |||
338 | # For classifier free guidance, we need to do two forward passes. | ||
339 | # Here we concatenate the unconditional and text embeddings into a single batch | ||
340 | # to avoid doing two forward passes | ||
341 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | ||
342 | |||
343 | # get the initial random noise unless the user supplied it | ||
344 | |||
345 | # Unlike in other pipelines, latents need to be generated in the target device | ||
346 | # for 1-to-1 results reproducibility with the CompVis implementation. | ||
347 | # However this currently doesn't work in `mps`. | ||
348 | latents_device = "cpu" if self.device.type == "mps" else self.device | ||
349 | latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) | ||
350 | if latents is None: | ||
351 | latents = torch.randn( | ||
352 | latents_shape, | ||
353 | generator=generator, | ||
354 | device=latents_device, | ||
355 | dtype=text_embeddings.dtype, | ||
356 | ) | ||
357 | else: | ||
358 | if latents.shape != latents_shape: | ||
359 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") | ||
360 | latents = latents.to(self.device) | ||
361 | |||
362 | # set timesteps | ||
363 | self.scheduler.set_timesteps(num_inference_steps) | ||
364 | |||
365 | # Some schedulers like PNDM have timesteps as arrays | ||
366 | # It's more optimzed to move all timesteps to correct device beforehand | ||
367 | if torch.is_tensor(self.scheduler.timesteps): | ||
368 | timesteps_tensor = self.scheduler.timesteps.to(self.device) | ||
369 | else: | ||
370 | timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device) | ||
371 | |||
372 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas | ||
373 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
374 | latents = latents * self.scheduler.sigmas[0] | ||
375 | elif isinstance(self.scheduler, EulerAScheduler): | ||
376 | sigma = self.scheduler.timesteps[0] | ||
377 | latents = latents * sigma | ||
378 | |||
379 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature | ||
380 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | ||
381 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | ||
382 | # and should be between [0, 1] | ||
383 | scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
384 | accepts_eta = "eta" in scheduler_step_args | ||
385 | extra_step_kwargs = {} | ||
386 | if accepts_eta: | ||
387 | extra_step_kwargs["eta"] = eta | ||
388 | accepts_generator = "generator" in scheduler_step_args | ||
389 | if generator is not None and accepts_generator: | ||
390 | extra_step_kwargs["generator"] = generator | ||
391 | |||
392 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): | ||
393 | # expand the latents if we are doing classifier free guidance | ||
394 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | ||
395 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
396 | sigma = self.scheduler.sigmas[i] | ||
397 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS | ||
398 | latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) | ||
399 | |||
400 | noise_pred = None | ||
401 | if isinstance(self.scheduler, EulerAScheduler): | ||
402 | sigma = t.reshape(1) | ||
403 | sigma_in = torch.cat([sigma] * 2) | ||
404 | # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) | ||
405 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | ||
406 | text_embeddings, guidance_scale, DSsigmas=self.scheduler.DSsigmas) | ||
407 | # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample | ||
408 | else: | ||
409 | # predict the noise residual | ||
410 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | ||
411 | |||
412 | # perform guidance | ||
413 | if do_classifier_free_guidance: | ||
414 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | ||
415 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | ||
416 | |||
417 | # perform clip guidance | ||
418 | if clip_guidance_scale > 0: | ||
419 | text_embeddings_for_guidance = ( | ||
420 | text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings | ||
421 | ) | ||
422 | noise_pred, latents = self.cond_fn( | ||
423 | latents, | ||
424 | t, | ||
425 | i, | ||
426 | text_embeddings_for_guidance, | ||
427 | noise_pred, | ||
428 | text_embeddings_clip, | ||
429 | clip_guidance_scale, | ||
430 | num_cutouts, | ||
431 | use_cutouts, | ||
432 | ) | ||
433 | |||
434 | # compute the previous noisy sample x_t -> x_t-1 | ||
435 | if isinstance(self.scheduler, LMSDiscreteScheduler): | ||
436 | latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample | ||
437 | elif isinstance(self.scheduler, EulerAScheduler): | ||
438 | if i < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error | ||
439 | t_prev = self.scheduler.timesteps[i+1] | ||
440 | latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample | ||
441 | else: | ||
442 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | ||
443 | |||
444 | # scale and decode the image latents with vae | ||
445 | latents = 1 / 0.18215 * latents | ||
446 | image = self.vae.decode(latents).sample | ||
447 | |||
448 | image = (image / 2 + 0.5).clamp(0, 1) | ||
449 | image = image.cpu().permute(0, 2, 3, 1).numpy() | ||
450 | |||
451 | if output_type == "pil": | ||
452 | image = self.numpy_to_pil(image) | ||
453 | |||
454 | if not return_dict: | ||
455 | return (image, None) | ||
456 | |||
457 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) | ||
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py new file mode 100644 index 0000000..57a56de --- /dev/null +++ b/schedulers/scheduling_euler_a.py | |||
@@ -0,0 +1,323 @@ | |||
1 | |||
2 | |||
3 | import math | ||
4 | import warnings | ||
5 | from typing import Optional, Tuple, Union | ||
6 | |||
7 | import numpy as np | ||
8 | import torch | ||
9 | |||
10 | from diffusers.configuration_utils import ConfigMixin, register_to_config | ||
11 | from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput | ||
12 | |||
13 | |||
14 | ''' | ||
15 | helper functions: append_zero(), | ||
16 | t_to_sigma(), | ||
17 | get_sigmas(), | ||
18 | append_dims(), | ||
19 | CFGDenoiserForward(), | ||
20 | get_scalings(), | ||
21 | DSsigma_to_t(), | ||
22 | DiscreteEpsDDPMDenoiserForward(), | ||
23 | to_d(), | ||
24 | get_ancestral_step() | ||
25 | need cleaning | ||
26 | ''' | ||
27 | |||
28 | |||
29 | def append_zero(x): | ||
30 | return torch.cat([x, x.new_zeros([1])]) | ||
31 | |||
32 | |||
33 | def t_to_sigma(t, sigmas): | ||
34 | t = t.float() | ||
35 | low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() | ||
36 | return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] | ||
37 | |||
38 | |||
39 | def get_sigmas(sigmas, n=None): | ||
40 | if n is None: | ||
41 | return append_zero(sigmas.flip(0)) | ||
42 | t_max = len(sigmas) - 1 # = 999 | ||
43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
44 | # t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
45 | return append_zero(t_to_sigma(t, sigmas)) | ||
46 | |||
47 | # from k_samplers utils.py | ||
48 | |||
49 | |||
50 | def append_dims(x, target_dims): | ||
51 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" | ||
52 | dims_to_append = target_dims - x.ndim | ||
53 | if dims_to_append < 0: | ||
54 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') | ||
55 | return x[(...,) + (None,) * dims_to_append] | ||
56 | |||
57 | |||
58 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): | ||
59 | # x_in = torch.cat([x] * 2)#A# concat the latent | ||
60 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma | ||
61 | # cond_in = torch.cat([uncond, cond]) | ||
62 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) | ||
63 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) | ||
64 | # return uncond + (cond - uncond) * cond_scale | ||
65 | noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) | ||
66 | return noise_pred | ||
67 | |||
68 | # from k_samplers sampling.py | ||
69 | |||
70 | |||
71 | def to_d(x, sigma, denoised): | ||
72 | """Converts a denoiser output to a Karras ODE derivative.""" | ||
73 | return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim) | ||
74 | |||
75 | |||
76 | def get_scalings(sigma): | ||
77 | sigma_data = 1. | ||
78 | c_out = -sigma | ||
79 | c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 | ||
80 | return c_out, c_in | ||
81 | |||
82 | # DiscreteSchedule DS | ||
83 | |||
84 | |||
85 | def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | ||
86 | # quantize = self.quantize if quantize is None else quantize | ||
87 | quantize = False | ||
88 | dists = torch.abs(sigma - DSsigmas[:, None]) | ||
89 | if quantize: | ||
90 | return torch.argmin(dists, dim=0).view(sigma.shape) | ||
91 | low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] | ||
92 | low, high = DSsigmas[low_idx], DSsigmas[high_idx] | ||
93 | w = (low - sigma) / (low - high) | ||
94 | w = w.clamp(0, 1) | ||
95 | t = (1 - w) * low_idx + w * high_idx | ||
96 | return t.view(sigma.shape) | ||
97 | |||
98 | |||
99 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): | ||
100 | sigma = sigma.to(Unet.device) | ||
101 | DSsigmas = DSsigmas.to(Unet.device) | ||
102 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] | ||
103 | # ??? what is eps? | ||
104 | # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs) | ||
105 | eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas), | ||
106 | encoder_hidden_states=kwargs['cond']).sample | ||
107 | return input + eps * c_out | ||
108 | |||
109 | |||
110 | # from k_samplers sampling.py | ||
111 | def get_ancestral_step(sigma_from, sigma_to): | ||
112 | """Calculates the noise level (sigma_down) to step down to and the amount | ||
113 | of noise to add (sigma_up) when doing an ancestral sampling step.""" | ||
114 | sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 | ||
115 | sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 | ||
116 | return sigma_down, sigma_up | ||
117 | |||
118 | |||
119 | ''' | ||
120 | Euler Ancestral Scheduler | ||
121 | ''' | ||
122 | |||
123 | |||
124 | class EulerAScheduler(SchedulerMixin, ConfigMixin): | ||
125 | """ | ||
126 | Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and | ||
127 | the VE column of Table 1 from [1] for reference. | ||
128 | |||
129 | [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." | ||
130 | https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic | ||
131 | differential equations." https://arxiv.org/abs/2011.13456 | ||
132 | |||
133 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` | ||
134 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. | ||
135 | [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and | ||
136 | [`~ConfigMixin.from_config`] functions. | ||
137 | |||
138 | For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of | ||
139 | Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the | ||
140 | optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. | ||
141 | |||
142 | Args: | ||
143 | sigma_min (`float`): minimum noise magnitude | ||
144 | sigma_max (`float`): maximum noise magnitude | ||
145 | s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. | ||
146 | A reasonable range is [1.000, 1.011]. | ||
147 | s_churn (`float`): the parameter controlling the overall amount of stochasticity. | ||
148 | A reasonable range is [0, 100]. | ||
149 | s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). | ||
150 | A reasonable range is [0, 10]. | ||
151 | s_max (`float`): the end value of the sigma range where we add noise. | ||
152 | A reasonable range is [0.2, 80]. | ||
153 | |||
154 | """ | ||
155 | |||
156 | @register_to_config | ||
157 | def __init__( | ||
158 | self, | ||
159 | num_train_timesteps: int = 1000, | ||
160 | beta_start: float = 0.0001, | ||
161 | beta_end: float = 0.02, | ||
162 | beta_schedule: str = "linear", | ||
163 | trained_betas: Optional[np.ndarray] = None, | ||
164 | clip_sample: bool = True, | ||
165 | set_alpha_to_one: bool = True, | ||
166 | steps_offset: int = 0, | ||
167 | ): | ||
168 | if trained_betas is not None: | ||
169 | self.betas = torch.from_numpy(trained_betas) | ||
170 | if beta_schedule == "linear": | ||
171 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) | ||
172 | elif beta_schedule == "scaled_linear": | ||
173 | # this schedule is very specific to the latent diffusion model. | ||
174 | self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 | ||
175 | elif beta_schedule == "squaredcos_cap_v2": | ||
176 | # Glide cosine schedule | ||
177 | self.betas = betas_for_alpha_bar(num_train_timesteps) | ||
178 | else: | ||
179 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") | ||
180 | |||
181 | self.alphas = 1.0 - self.betas | ||
182 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) | ||
183 | |||
184 | # At every step in ddim, we are looking into the previous alphas_cumprod | ||
185 | # For the final step, there is no previous alphas_cumprod because we are already at 0 | ||
186 | # `set_alpha_to_one` decides whether we set this parameter simply to one or | ||
187 | # whether we use the final alpha of the "non-previous" one. | ||
188 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] | ||
189 | |||
190 | # setable values | ||
191 | self.num_inference_steps = None | ||
192 | self.timesteps = np.arange(0, num_train_timesteps)[::-1] | ||
193 | |||
194 | # A# take number of steps as input | ||
195 | # A# store 1) number of steps 2) timesteps 3) schedule | ||
196 | |||
197 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): | ||
198 | """ | ||
199 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. | ||
200 | |||
201 | Args: | ||
202 | num_inference_steps (`int`): | ||
203 | the number of diffusion steps used when generating samples with a pre-trained model. | ||
204 | """ | ||
205 | |||
206 | # offset = self.config.steps_offset | ||
207 | |||
208 | # if "offset" in kwargs: | ||
209 | # warnings.warn( | ||
210 | # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." | ||
211 | # " Please pass `steps_offset` to `__init__` instead.", | ||
212 | # DeprecationWarning, | ||
213 | # ) | ||
214 | |||
215 | # offset = kwargs["offset"] | ||
216 | |||
217 | self.num_inference_steps = num_inference_steps | ||
218 | self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 | ||
219 | self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) | ||
220 | self.timesteps = self.sigmas | ||
221 | |||
222 | def add_noise_to_input( | ||
223 | self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None | ||
224 | ) -> Tuple[torch.FloatTensor, float]: | ||
225 | """ | ||
226 | Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a | ||
227 | higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. | ||
228 | |||
229 | TODO Args: | ||
230 | """ | ||
231 | if self.config.s_min <= sigma <= self.config.s_max: | ||
232 | gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) | ||
233 | else: | ||
234 | gamma = 0 | ||
235 | |||
236 | # sample eps ~ N(0, S_noise^2 * I) | ||
237 | eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) | ||
238 | sigma_hat = sigma + gamma * sigma | ||
239 | sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) | ||
240 | |||
241 | return sample_hat, sigma_hat | ||
242 | |||
243 | def step( | ||
244 | self, | ||
245 | model_output: torch.FloatTensor, | ||
246 | timestep: torch.IntTensor, | ||
247 | timestep_prev: torch.IntTensor, | ||
248 | sample: torch.FloatTensor, | ||
249 | generator: None, | ||
250 | # ,sigma_hat: float, | ||
251 | # sigma_prev: float, | ||
252 | # sample_hat: torch.FloatTensor, | ||
253 | return_dict: bool = True, | ||
254 | ) -> Union[SchedulerOutput, Tuple]: | ||
255 | """ | ||
256 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||
257 | process from the learned model outputs (most often the predicted noise). | ||
258 | |||
259 | Args: | ||
260 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||
261 | sigma_hat (`float`): TODO | ||
262 | sigma_prev (`float`): TODO | ||
263 | sample_hat (`torch.FloatTensor`): TODO | ||
264 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
265 | |||
266 | EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check). | ||
267 | Returns: | ||
268 | [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`: | ||
269 | [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When | ||
270 | returning a tuple, the first element is the sample tensor. | ||
271 | |||
272 | """ | ||
273 | latents = sample | ||
274 | sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev) | ||
275 | |||
276 | # if callback is not None: | ||
277 | # callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output}) | ||
278 | d = to_d(latents, timestep, model_output) | ||
279 | # Euler method | ||
280 | dt = sigma_down - timestep | ||
281 | latents = latents + d * dt | ||
282 | latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, | ||
283 | generator=generator) * sigma_up | ||
284 | return SchedulerOutput(prev_sample=latents) | ||
285 | |||
286 | def step_correct( | ||
287 | self, | ||
288 | model_output: torch.FloatTensor, | ||
289 | sigma_hat: float, | ||
290 | sigma_prev: float, | ||
291 | sample_hat: torch.FloatTensor, | ||
292 | sample_prev: torch.FloatTensor, | ||
293 | derivative: torch.FloatTensor, | ||
294 | generator: None, | ||
295 | return_dict: bool = True, | ||
296 | ) -> Union[SchedulerOutput, Tuple]: | ||
297 | """ | ||
298 | Correct the predicted sample based on the output model_output of the network. TODO complete description | ||
299 | |||
300 | Args: | ||
301 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. | ||
302 | sigma_hat (`float`): TODO | ||
303 | sigma_prev (`float`): TODO | ||
304 | sample_hat (`torch.FloatTensor`): TODO | ||
305 | sample_prev (`torch.FloatTensor`): TODO | ||
306 | derivative (`torch.FloatTensor`): TODO | ||
307 | return_dict (`bool`): option for returning tuple rather than SchedulerOutput class | ||
308 | |||
309 | Returns: | ||
310 | prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO | ||
311 | |||
312 | """ | ||
313 | pred_original_sample = sample_prev + sigma_prev * model_output | ||
314 | derivative_corr = (sample_prev - pred_original_sample) / sigma_prev | ||
315 | sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) | ||
316 | |||
317 | if not return_dict: | ||
318 | return (sample_prev, derivative) | ||
319 | |||
320 | return SchedulerOutput(prev_sample=sample_prev) | ||
321 | |||
322 | def add_noise(self, original_samples, noise, timesteps): | ||
323 | raise NotImplementedError() | ||