summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-09-30 14:13:51 +0200
committerVolpeon <git@volpeon.ink>2022-09-30 14:13:51 +0200
commit9a42def9fcfb9a5c5471d640253ed6c8f45c4973 (patch)
treead186862f5095663966dd1d42455023080aa0c4e
parentBetter sample file structure (diff)
downloadtextual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.gz
textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.bz2
textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.zip
Added custom SD pipeline + euler_a scheduler
-rw-r--r--.gitignore2
-rw-r--r--dreambooth.py12
-rw-r--r--infer.py111
-rw-r--r--pipelines/stable_diffusion/clip_guided_stable_diffusion.py457
-rw-r--r--schedulers/scheduling_euler_a.py323
5 files changed, 869 insertions, 36 deletions
diff --git a/.gitignore b/.gitignore
index 4456cef..35b4c22 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,5 +160,5 @@ cython_debug/
160#.idea/ 160#.idea/
161 161
162output/ 162output/
163conf*.json 163conf/
164v1-inference.yaml* 164v1-inference.yaml*
diff --git a/dreambooth.py b/dreambooth.py
index 39c4851..4d7366c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -59,7 +59,7 @@ def parse_args():
59 parser.add_argument( 59 parser.add_argument(
60 "--repeats", 60 "--repeats",
61 type=int, 61 type=int,
62 default=100, 62 default=1,
63 help="How many times to repeat the training data." 63 help="How many times to repeat the training data."
64 ) 64 )
65 parser.add_argument( 65 parser.add_argument(
@@ -375,7 +375,6 @@ class Checkpointer:
375 @torch.no_grad() 375 @torch.no_grad()
376 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): 376 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
377 samples_path = Path(self.output_dir).joinpath("samples") 377 samples_path = Path(self.output_dir).joinpath("samples")
378 samples_path.mkdir(parents=True, exist_ok=True)
379 378
380 unwrapped = self.accelerator.unwrap_model(self.unet) 379 unwrapped = self.accelerator.unwrap_model(self.unet)
381 pipeline = StableDiffusionPipeline( 380 pipeline = StableDiffusionPipeline(
@@ -403,6 +402,7 @@ class Checkpointer:
403 402
404 all_samples = [] 403 all_samples = []
405 file_path = samples_path.joinpath("stable", f"step_{step}.png") 404 file_path = samples_path.joinpath("stable", f"step_{step}.png")
405 file_path.parent.mkdir(parents=True, exist_ok=True)
406 406
407 data_enum = enumerate(val_data) 407 data_enum = enumerate(val_data)
408 408
@@ -436,6 +436,7 @@ class Checkpointer:
436 for data, pool in [(val_data, "val"), (train_data, "train")]: 436 for data, pool in [(val_data, "val"), (train_data, "train")]:
437 all_samples = [] 437 all_samples = []
438 file_path = samples_path.joinpath(pool, f"step_{step}.png") 438 file_path = samples_path.joinpath(pool, f"step_{step}.png")
439 file_path.parent.mkdir(parents=True, exist_ok=True)
439 440
440 data_enum = enumerate(data) 441 data_enum = enumerate(data)
441 442
@@ -496,11 +497,15 @@ def main():
496 cur_class_images = len(list(class_images_dir.iterdir())) 497 cur_class_images = len(list(class_images_dir.iterdir()))
497 498
498 if cur_class_images < args.num_class_images: 499 if cur_class_images < args.num_class_images:
499 torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32 500 torch_dtype = torch.float32
501 if accelerator.device.type == "cuda":
502 torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
503
500 pipeline = StableDiffusionPipeline.from_pretrained( 504 pipeline = StableDiffusionPipeline.from_pretrained(
501 args.pretrained_model_name_or_path, torch_dtype=torch_dtype) 505 args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
502 pipeline.enable_attention_slicing() 506 pipeline.enable_attention_slicing()
503 pipeline.set_progress_bar_config(disable=True) 507 pipeline.set_progress_bar_config(disable=True)
508 pipeline.to(accelerator.device)
504 509
505 num_new_images = args.num_class_images - cur_class_images 510 num_new_images = args.num_class_images - cur_class_images
506 logger.info(f"Number of class images to sample: {num_new_images}.") 511 logger.info(f"Number of class images to sample: {num_new_images}.")
@@ -509,7 +514,6 @@ def main():
509 sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 514 sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
510 515
511 sample_dataloader = accelerator.prepare(sample_dataloader) 516 sample_dataloader = accelerator.prepare(sample_dataloader)
512 pipeline.to(accelerator.device)
513 517
514 for example in tqdm( 518 for example in tqdm(
515 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 519 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
diff --git a/infer.py b/infer.py
index f2007e9..de3d792 100644
--- a/infer.py
+++ b/infer.py
@@ -1,18 +1,15 @@
1import argparse 1import argparse
2import datetime 2import datetime
3import logging
3from pathlib import Path 4from pathlib import Path
4from torch import autocast 5from torch import autocast
5from diffusers import StableDiffusionPipeline
6import torch 6import torch
7import json 7import json
8from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler 8from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
9from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor 9from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
10from slugify import slugify 10from slugify import slugify
11from pipelines.stable_diffusion.no_check import NoCheck 11from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
12 12from schedulers.scheduling_euler_a import EulerAScheduler
13model_id = "path-to-your-trained-model"
14
15prompt = "A photo of sks dog in a bucket"
16 13
17 14
18def parse_args(): 15def parse_args():
@@ -30,6 +27,21 @@ def parse_args():
30 default=None, 27 default=None,
31 ) 28 )
32 parser.add_argument( 29 parser.add_argument(
30 "--negative_prompt",
31 type=str,
32 default=None,
33 )
34 parser.add_argument(
35 "--width",
36 type=int,
37 default=512,
38 )
39 parser.add_argument(
40 "--height",
41 type=int,
42 default=512,
43 )
44 parser.add_argument(
33 "--batch_size", 45 "--batch_size",
34 type=int, 46 type=int,
35 default=1, 47 default=1,
@@ -42,17 +54,28 @@ def parse_args():
42 parser.add_argument( 54 parser.add_argument(
43 "--steps", 55 "--steps",
44 type=int, 56 type=int,
45 default=80, 57 default=120,
46 ) 58 )
47 parser.add_argument( 59 parser.add_argument(
48 "--scale", 60 "--scheduler",
61 type=str,
62 choices=["plms", "ddim", "klms", "euler_a"],
63 default="euler_a",
64 )
65 parser.add_argument(
66 "--guidance_scale",
49 type=int, 67 type=int,
50 default=7.5, 68 default=7.5,
51 ) 69 )
52 parser.add_argument( 70 parser.add_argument(
71 "--clip_guidance_scale",
72 type=int,
73 default=100,
74 )
75 parser.add_argument(
53 "--seed", 76 "--seed",
54 type=int, 77 type=int,
55 default=None, 78 default=torch.random.seed(),
56 ) 79 )
57 parser.add_argument( 80 parser.add_argument(
58 "--output_dir", 81 "--output_dir",
@@ -81,31 +104,39 @@ def save_args(basepath, args, extra={}):
81 json.dump(info, f, indent=4) 104 json.dump(info, f, indent=4)
82 105
83 106
84def main(): 107def gen(args, output_dir):
85 args = parse_args()
86
87 seed = args.seed or torch.random.seed()
88
89 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
90 output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
91 output_dir.mkdir(parents=True, exist_ok=True)
92 save_args(output_dir, args)
93
94 tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) 108 tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
95 text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) 109 text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
110 clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
96 vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16) 111 vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
97 unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16) 112 unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
98 feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16) 113 feature_extractor = CLIPFeatureExtractor.from_pretrained(
114 "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
99 115
100 pipeline = StableDiffusionPipeline( 116 if args.scheduler == "plms":
117 scheduler = PNDMScheduler(
118 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
119 )
120 elif args.scheduler == "klms":
121 scheduler = LMSDiscreteScheduler(
122 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
123 )
124 elif args.scheduler == "ddim":
125 scheduler = DDIMScheduler(
126 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
127 )
128 else:
129 scheduler = EulerAScheduler(
130 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
131 )
132
133 pipeline = CLIPGuidedStableDiffusion(
101 text_encoder=text_encoder, 134 text_encoder=text_encoder,
102 vae=vae, 135 vae=vae,
103 unet=unet, 136 unet=unet,
104 tokenizer=tokenizer, 137 tokenizer=tokenizer,
105 scheduler=PNDMScheduler( 138 clip_model=clip_model,
106 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True 139 scheduler=scheduler,
107 ),
108 safety_checker=NoCheck(),
109 feature_extractor=feature_extractor 140 feature_extractor=feature_extractor
110 ) 141 )
111 pipeline.enable_attention_slicing() 142 pipeline.enable_attention_slicing()
@@ -113,16 +144,34 @@ def main():
113 144
114 with autocast("cuda"): 145 with autocast("cuda"):
115 for i in range(args.batch_num): 146 for i in range(args.batch_num):
116 generator = torch.Generator(device="cuda").manual_seed(seed + i) 147 generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
117 images = pipeline( 148 images = pipeline(
118 [args.prompt] * args.batch_size, 149 prompt=[args.prompt] * args.batch_size,
150 height=args.height,
151 width=args.width,
152 negative_prompt=args.negative_prompt,
119 num_inference_steps=args.steps, 153 num_inference_steps=args.steps,
120 guidance_scale=args.scale, 154 guidance_scale=args.guidance_scale,
155 clip_guidance_scale=args.clip_guidance_scale,
121 generator=generator, 156 generator=generator,
122 ).images 157 ).images
123 158
124 for j, image in enumerate(images): 159 for j, image in enumerate(images):
125 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) 160 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
161
162
163def main():
164 args = parse_args()
165
166 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
167 output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
168 output_dir.mkdir(parents=True, exist_ok=True)
169
170 save_args(output_dir, args)
171
172 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
173
174 gen(args, output_dir)
126 175
127 176
128if __name__ == "__main__": 177if __name__ == "__main__":
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
new file mode 100644
index 0000000..306d9a9
--- /dev/null
+++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
@@ -0,0 +1,457 @@
1import inspect
2import warnings
3from typing import List, Optional, Union
4
5import torch
6from torch import nn
7from torch.nn import functional as F
8
9from diffusers.configuration_utils import FrozenDict
10from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
11from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
12from diffusers.utils import logging
13from torchvision import transforms
14from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
15from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
16
17logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
19
20class MakeCutouts(nn.Module):
21 def __init__(self, cut_size, cut_power=1.0):
22 super().__init__()
23
24 self.cut_size = cut_size
25 self.cut_power = cut_power
26
27 def forward(self, pixel_values, num_cutouts):
28 sideY, sideX = pixel_values.shape[2:4]
29 max_size = min(sideX, sideY)
30 min_size = min(sideX, sideY, self.cut_size)
31 cutouts = []
32 for _ in range(num_cutouts):
33 size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
34 offsetx = torch.randint(0, sideX - size + 1, ())
35 offsety = torch.randint(0, sideY - size + 1, ())
36 cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size]
37 cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
38 return torch.cat(cutouts)
39
40
41def spherical_dist_loss(x, y):
42 x = F.normalize(x, dim=-1)
43 y = F.normalize(y, dim=-1)
44 return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
45
46
47def set_requires_grad(model, value):
48 for param in model.parameters():
49 param.requires_grad = value
50
51
52class CLIPGuidedStableDiffusion(DiffusionPipeline):
53 """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
54 - https://github.com/Jack000/glid-3-xl
55 - https://github.dev/crowsonkb/k-diffusion
56 """
57
58 def __init__(
59 self,
60 vae: AutoencoderKL,
61 text_encoder: CLIPTextModel,
62 clip_model: CLIPModel,
63 tokenizer: CLIPTokenizer,
64 unet: UNet2DConditionModel,
65 scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
66 feature_extractor: CLIPFeatureExtractor,
67 **kwargs,
68 ):
69 super().__init__()
70
71 if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
72 warnings.warn(
73 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
74 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
75 "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
76 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
77 " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
78 " file",
79 DeprecationWarning,
80 )
81 new_config = dict(scheduler.config)
82 new_config["steps_offset"] = 1
83 scheduler._internal_dict = FrozenDict(new_config)
84
85 self.register_modules(
86 vae=vae,
87 text_encoder=text_encoder,
88 clip_model=clip_model,
89 tokenizer=tokenizer,
90 unet=unet,
91 scheduler=scheduler,
92 feature_extractor=feature_extractor,
93 )
94
95 self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
96 self.make_cutouts = MakeCutouts(feature_extractor.size)
97
98 set_requires_grad(self.text_encoder, False)
99 set_requires_grad(self.clip_model, False)
100
101 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
102 r"""
103 Enable sliced attention computation.
104
105 When this option is enabled, the attention module will split the input tensor in slices, to compute attention
106 in several steps. This is useful to save some memory in exchange for a small speed decrease.
107
108 Args:
109 slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
110 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
111 a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
112 `attention_head_dim` must be a multiple of `slice_size`.
113 """
114 if slice_size == "auto":
115 # half the attention head size is usually a good trade-off between
116 # speed and memory
117 slice_size = self.unet.config.attention_head_dim // 2
118 self.unet.set_attention_slice(slice_size)
119
120 def disable_attention_slicing(self):
121 r"""
122 Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
123 back to computing attention in one step.
124 """
125 # set slice_size = `None` to disable `attention slicing`
126 self.enable_attention_slicing(None)
127
128 def freeze_vae(self):
129 set_requires_grad(self.vae, False)
130
131 def unfreeze_vae(self):
132 set_requires_grad(self.vae, True)
133
134 def freeze_unet(self):
135 set_requires_grad(self.unet, False)
136
137 def unfreeze_unet(self):
138 set_requires_grad(self.unet, True)
139
140 @torch.enable_grad()
141 def cond_fn(
142 self,
143 latents,
144 timestep,
145 index,
146 text_embeddings,
147 noise_pred_original,
148 text_embeddings_clip,
149 clip_guidance_scale,
150 num_cutouts,
151 use_cutouts=True,
152 ):
153 latents = latents.detach().requires_grad_()
154
155 if isinstance(self.scheduler, LMSDiscreteScheduler):
156 sigma = self.scheduler.sigmas[index]
157 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
158 latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
159 else:
160 latent_model_input = latents
161
162 # predict the noise residual
163 noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
164
165 if isinstance(self.scheduler, PNDMScheduler):
166 alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
167 beta_prod_t = 1 - alpha_prod_t
168 # compute predicted original sample from predicted noise also called
169 # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
170 pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
171
172 fac = torch.sqrt(beta_prod_t)
173 sample = pred_original_sample * (fac) + latents * (1 - fac)
174 elif isinstance(self.scheduler, LMSDiscreteScheduler):
175 sigma = self.scheduler.sigmas[index]
176 sample = latents - sigma * noise_pred
177 else:
178 raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
179
180 sample = 1 / 0.18215 * sample
181 image = self.vae.decode(sample).sample
182 image = (image / 2 + 0.5).clamp(0, 1)
183
184 if use_cutouts:
185 image = self.make_cutouts(image, num_cutouts)
186 else:
187 image = transforms.Resize(self.feature_extractor.size)(image)
188 image = self.normalize(image)
189
190 image_embeddings_clip = self.clip_model.get_image_features(image).float()
191 image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
192
193 if use_cutouts:
194 dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
195 dists = dists.view([num_cutouts, sample.shape[0], -1])
196 loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
197 else:
198 loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
199
200 grads = -torch.autograd.grad(loss, latents)[0]
201
202 if isinstance(self.scheduler, LMSDiscreteScheduler):
203 latents = latents.detach() + grads * (sigma**2)
204 noise_pred = noise_pred_original
205 else:
206 noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
207 return noise_pred, latents
208
209 @torch.no_grad()
210 def __call__(
211 self,
212 prompt: Union[str, List[str]],
213 negative_prompt: Optional[Union[str, List[str]]] = None,
214 height: Optional[int] = 512,
215 width: Optional[int] = 512,
216 num_inference_steps: Optional[int] = 50,
217 guidance_scale: Optional[float] = 7.5,
218 eta: Optional[float] = 0.0,
219 clip_guidance_scale: Optional[float] = 100,
220 clip_prompt: Optional[Union[str, List[str]]] = None,
221 num_cutouts: Optional[int] = 4,
222 use_cutouts: Optional[bool] = True,
223 generator: Optional[torch.Generator] = None,
224 latents: Optional[torch.FloatTensor] = None,
225 output_type: Optional[str] = "pil",
226 return_dict: bool = True,
227 ):
228 r"""
229 Function invoked when calling the pipeline for generation.
230
231 Args:
232 prompt (`str` or `List[str]`):
233 The prompt or prompts to guide the image generation.
234 height (`int`, *optional*, defaults to 512):
235 The height in pixels of the generated image.
236 width (`int`, *optional*, defaults to 512):
237 The width in pixels of the generated image.
238 num_inference_steps (`int`, *optional*, defaults to 50):
239 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
240 expense of slower inference.
241 guidance_scale (`float`, *optional*, defaults to 7.5):
242 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
243 `guidance_scale` is defined as `w` of equation 2. of [Imagen
244 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
245 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
246 usually at the expense of lower image quality.
247 eta (`float`, *optional*, defaults to 0.0):
248 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
249 [`schedulers.DDIMScheduler`], will be ignored for others.
250 generator (`torch.Generator`, *optional*):
251 A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
252 deterministic.
253 latents (`torch.FloatTensor`, *optional*):
254 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
255 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
256 tensor will ge generated by sampling using the supplied random `generator`.
257 output_type (`str`, *optional*, defaults to `"pil"`):
258 The output format of the generate image. Choose between
259 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
260 return_dict (`bool`, *optional*, defaults to `True`):
261 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
262 plain tuple.
263
264 Returns:
265 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
266 [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
267 When returning a tuple, the first element is a list with the generated images, and the second element is a
268 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
269 (nsfw) content, according to the `safety_checker`.
270 """
271
272 if isinstance(prompt, str):
273 batch_size = 1
274 elif isinstance(prompt, list):
275 batch_size = len(prompt)
276 else:
277 raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
278
279 if negative_prompt is None:
280 negative_prompt = [""] * batch_size
281 elif isinstance(negative_prompt, str):
282 negative_prompt = [negative_prompt] * batch_size
283 elif isinstance(negative_prompt, list):
284 if len(negative_prompt) != batch_size:
285 raise ValueError(
286 f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
287 else:
288 raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
289
290 if height % 8 != 0 or width % 8 != 0:
291 raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
292
293 # get prompt text embeddings
294 text_inputs = self.tokenizer(
295 prompt,
296 padding="max_length",
297 max_length=self.tokenizer.model_max_length,
298 return_tensors="pt",
299 )
300 text_input_ids = text_inputs.input_ids
301
302 if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
303 removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
304 logger.warning(
305 "The following part of your input was truncated because CLIP can only handle sequences up to"
306 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
307 )
308 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
309 text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
310
311 if clip_guidance_scale > 0:
312 if clip_prompt is not None:
313 clip_text_inputs = self.tokenizer(
314 clip_prompt,
315 padding="max_length",
316 max_length=self.tokenizer.model_max_length,
317 truncation=True,
318 return_tensors="pt",
319 )
320 clip_text_input_ids = clip_text_inputs.input_ids
321 else:
322 clip_text_input_ids = text_input_ids
323 text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device))
324 text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
325
326 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
327 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
328 # corresponds to doing no classifier free guidance.
329 do_classifier_free_guidance = guidance_scale > 1.0
330 # get unconditional embeddings for classifier free guidance
331 if do_classifier_free_guidance:
332 max_length = text_input_ids.shape[-1]
333 uncond_input = self.tokenizer(
334 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
335 )
336 uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
337
338 # For classifier free guidance, we need to do two forward passes.
339 # Here we concatenate the unconditional and text embeddings into a single batch
340 # to avoid doing two forward passes
341 text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
342
343 # get the initial random noise unless the user supplied it
344
345 # Unlike in other pipelines, latents need to be generated in the target device
346 # for 1-to-1 results reproducibility with the CompVis implementation.
347 # However this currently doesn't work in `mps`.
348 latents_device = "cpu" if self.device.type == "mps" else self.device
349 latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
350 if latents is None:
351 latents = torch.randn(
352 latents_shape,
353 generator=generator,
354 device=latents_device,
355 dtype=text_embeddings.dtype,
356 )
357 else:
358 if latents.shape != latents_shape:
359 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
360 latents = latents.to(self.device)
361
362 # set timesteps
363 self.scheduler.set_timesteps(num_inference_steps)
364
365 # Some schedulers like PNDM have timesteps as arrays
366 # It's more optimzed to move all timesteps to correct device beforehand
367 if torch.is_tensor(self.scheduler.timesteps):
368 timesteps_tensor = self.scheduler.timesteps.to(self.device)
369 else:
370 timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
371
372 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
373 if isinstance(self.scheduler, LMSDiscreteScheduler):
374 latents = latents * self.scheduler.sigmas[0]
375 elif isinstance(self.scheduler, EulerAScheduler):
376 sigma = self.scheduler.timesteps[0]
377 latents = latents * sigma
378
379 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
380 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
381 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
382 # and should be between [0, 1]
383 scheduler_step_args = set(inspect.signature(self.scheduler.step).parameters.keys())
384 accepts_eta = "eta" in scheduler_step_args
385 extra_step_kwargs = {}
386 if accepts_eta:
387 extra_step_kwargs["eta"] = eta
388 accepts_generator = "generator" in scheduler_step_args
389 if generator is not None and accepts_generator:
390 extra_step_kwargs["generator"] = generator
391
392 for i, t in enumerate(self.progress_bar(timesteps_tensor)):
393 # expand the latents if we are doing classifier free guidance
394 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
395 if isinstance(self.scheduler, LMSDiscreteScheduler):
396 sigma = self.scheduler.sigmas[i]
397 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
398 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
399
400 noise_pred = None
401 if isinstance(self.scheduler, EulerAScheduler):
402 sigma = t.reshape(1)
403 sigma_in = torch.cat([sigma] * 2)
404 # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale)
405 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
406 text_embeddings, guidance_scale, DSsigmas=self.scheduler.DSsigmas)
407 # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
408 else:
409 # predict the noise residual
410 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
411
412 # perform guidance
413 if do_classifier_free_guidance:
414 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
415 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
416
417 # perform clip guidance
418 if clip_guidance_scale > 0:
419 text_embeddings_for_guidance = (
420 text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
421 )
422 noise_pred, latents = self.cond_fn(
423 latents,
424 t,
425 i,
426 text_embeddings_for_guidance,
427 noise_pred,
428 text_embeddings_clip,
429 clip_guidance_scale,
430 num_cutouts,
431 use_cutouts,
432 )
433
434 # compute the previous noisy sample x_t -> x_t-1
435 if isinstance(self.scheduler, LMSDiscreteScheduler):
436 latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
437 elif isinstance(self.scheduler, EulerAScheduler):
438 if i < self.scheduler.timesteps.shape[0] - 1: # avoid out of bound error
439 t_prev = self.scheduler.timesteps[i+1]
440 latents = self.scheduler.step(noise_pred, t, t_prev, latents, **extra_step_kwargs).prev_sample
441 else:
442 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
443
444 # scale and decode the image latents with vae
445 latents = 1 / 0.18215 * latents
446 image = self.vae.decode(latents).sample
447
448 image = (image / 2 + 0.5).clamp(0, 1)
449 image = image.cpu().permute(0, 2, 3, 1).numpy()
450
451 if output_type == "pil":
452 image = self.numpy_to_pil(image)
453
454 if not return_dict:
455 return (image, None)
456
457 return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
new file mode 100644
index 0000000..57a56de
--- /dev/null
+++ b/schedulers/scheduling_euler_a.py
@@ -0,0 +1,323 @@
1
2
3import math
4import warnings
5from typing import Optional, Tuple, Union
6
7import numpy as np
8import torch
9
10from diffusers.configuration_utils import ConfigMixin, register_to_config
11from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
12
13
14'''
15helper functions: append_zero(),
16 t_to_sigma(),
17 get_sigmas(),
18 append_dims(),
19 CFGDenoiserForward(),
20 get_scalings(),
21 DSsigma_to_t(),
22 DiscreteEpsDDPMDenoiserForward(),
23 to_d(),
24 get_ancestral_step()
25need cleaning
26'''
27
28
29def append_zero(x):
30 return torch.cat([x, x.new_zeros([1])])
31
32
33def t_to_sigma(t, sigmas):
34 t = t.float()
35 low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
36 return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]
37
38
39def get_sigmas(sigmas, n=None):
40 if n is None:
41 return append_zero(sigmas.flip(0))
42 t_max = len(sigmas) - 1 # = 999
43 t = torch.linspace(t_max, 0, n, device=sigmas.device)
44 # t = torch.linspace(t_max, 0, n, device=sigmas.device)
45 return append_zero(t_to_sigma(t, sigmas))
46
47# from k_samplers utils.py
48
49
50def append_dims(x, target_dims):
51 """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
52 dims_to_append = target_dims - x.ndim
53 if dims_to_append < 0:
54 raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
55 return x[(...,) + (None,) * dims_to_append]
56
57
58def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None):
59 # x_in = torch.cat([x] * 2)#A# concat the latent
60 # sigma_in = torch.cat([sigma] * 2) #A# concat sigma
61 # cond_in = torch.cat([uncond, cond])
62 # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
63 # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2)
64 # return uncond + (cond - uncond) * cond_scale
65 noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in)
66 return noise_pred
67
68# from k_samplers sampling.py
69
70
71def to_d(x, sigma, denoised):
72 """Converts a denoiser output to a Karras ODE derivative."""
73 return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim)
74
75
76def get_scalings(sigma):
77 sigma_data = 1.
78 c_out = -sigma
79 c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
80 return c_out, c_in
81
82# DiscreteSchedule DS
83
84
85def DSsigma_to_t(sigma, quantize=None, DSsigmas=None):
86 # quantize = self.quantize if quantize is None else quantize
87 quantize = False
88 dists = torch.abs(sigma - DSsigmas[:, None])
89 if quantize:
90 return torch.argmin(dists, dim=0).view(sigma.shape)
91 low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
92 low, high = DSsigmas[low_idx], DSsigmas[high_idx]
93 w = (low - sigma) / (low - high)
94 w = w.clamp(0, 1)
95 t = (1 - w) * low_idx + w * high_idx
96 return t.view(sigma.shape)
97
98
99def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs):
100 sigma = sigma.to(Unet.device)
101 DSsigmas = DSsigmas.to(Unet.device)
102 c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
103 # ??? what is eps?
104 # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs)
105 eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas),
106 encoder_hidden_states=kwargs['cond']).sample
107 return input + eps * c_out
108
109
110# from k_samplers sampling.py
111def get_ancestral_step(sigma_from, sigma_to):
112 """Calculates the noise level (sigma_down) to step down to and the amount
113 of noise to add (sigma_up) when doing an ancestral sampling step."""
114 sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
115 sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
116 return sigma_down, sigma_up
117
118
119'''
120Euler Ancestral Scheduler
121'''
122
123
124class EulerAScheduler(SchedulerMixin, ConfigMixin):
125 """
126 Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
127 the VE column of Table 1 from [1] for reference.
128
129 [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
130 https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
131 differential equations." https://arxiv.org/abs/2011.13456
132
133 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
134 function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
135 [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
136 [`~ConfigMixin.from_config`] functions.
137
138 For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
139 Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
140 optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
141
142 Args:
143 sigma_min (`float`): minimum noise magnitude
144 sigma_max (`float`): maximum noise magnitude
145 s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
146 A reasonable range is [1.000, 1.011].
147 s_churn (`float`): the parameter controlling the overall amount of stochasticity.
148 A reasonable range is [0, 100].
149 s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
150 A reasonable range is [0, 10].
151 s_max (`float`): the end value of the sigma range where we add noise.
152 A reasonable range is [0.2, 80].
153
154 """
155
156 @register_to_config
157 def __init__(
158 self,
159 num_train_timesteps: int = 1000,
160 beta_start: float = 0.0001,
161 beta_end: float = 0.02,
162 beta_schedule: str = "linear",
163 trained_betas: Optional[np.ndarray] = None,
164 clip_sample: bool = True,
165 set_alpha_to_one: bool = True,
166 steps_offset: int = 0,
167 ):
168 if trained_betas is not None:
169 self.betas = torch.from_numpy(trained_betas)
170 if beta_schedule == "linear":
171 self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
172 elif beta_schedule == "scaled_linear":
173 # this schedule is very specific to the latent diffusion model.
174 self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
175 elif beta_schedule == "squaredcos_cap_v2":
176 # Glide cosine schedule
177 self.betas = betas_for_alpha_bar(num_train_timesteps)
178 else:
179 raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
180
181 self.alphas = 1.0 - self.betas
182 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
183
184 # At every step in ddim, we are looking into the previous alphas_cumprod
185 # For the final step, there is no previous alphas_cumprod because we are already at 0
186 # `set_alpha_to_one` decides whether we set this parameter simply to one or
187 # whether we use the final alpha of the "non-previous" one.
188 self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
189
190 # setable values
191 self.num_inference_steps = None
192 self.timesteps = np.arange(0, num_train_timesteps)[::-1]
193
194 # A# take number of steps as input
195 # A# store 1) number of steps 2) timesteps 3) schedule
196
197 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
198 """
199 Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
200
201 Args:
202 num_inference_steps (`int`):
203 the number of diffusion steps used when generating samples with a pre-trained model.
204 """
205
206 # offset = self.config.steps_offset
207
208 # if "offset" in kwargs:
209 # warnings.warn(
210 # "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
211 # " Please pass `steps_offset` to `__init__` instead.",
212 # DeprecationWarning,
213 # )
214
215 # offset = kwargs["offset"]
216
217 self.num_inference_steps = num_inference_steps
218 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
219 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
220 self.timesteps = self.sigmas
221
222 def add_noise_to_input(
223 self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
224 ) -> Tuple[torch.FloatTensor, float]:
225 """
226 Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
227 higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
228
229 TODO Args:
230 """
231 if self.config.s_min <= sigma <= self.config.s_max:
232 gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
233 else:
234 gamma = 0
235
236 # sample eps ~ N(0, S_noise^2 * I)
237 eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
238 sigma_hat = sigma + gamma * sigma
239 sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
240
241 return sample_hat, sigma_hat
242
243 def step(
244 self,
245 model_output: torch.FloatTensor,
246 timestep: torch.IntTensor,
247 timestep_prev: torch.IntTensor,
248 sample: torch.FloatTensor,
249 generator: None,
250 # ,sigma_hat: float,
251 # sigma_prev: float,
252 # sample_hat: torch.FloatTensor,
253 return_dict: bool = True,
254 ) -> Union[SchedulerOutput, Tuple]:
255 """
256 Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
257 process from the learned model outputs (most often the predicted noise).
258
259 Args:
260 model_output (`torch.FloatTensor`): direct output from learned diffusion model.
261 sigma_hat (`float`): TODO
262 sigma_prev (`float`): TODO
263 sample_hat (`torch.FloatTensor`): TODO
264 return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
265
266 EulerAOutput: updated sample in the diffusion chain and derivative (TODO double check).
267 Returns:
268 [`~schedulers.scheduling_karras_ve.EulerAOutput`] or `tuple`:
269 [`~schedulers.scheduling_karras_ve.EulerAOutput`] if `return_dict` is True, otherwise a `tuple`. When
270 returning a tuple, the first element is the sample tensor.
271
272 """
273 latents = sample
274 sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
275
276 # if callback is not None:
277 # callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output})
278 d = to_d(latents, timestep, model_output)
279 # Euler method
280 dt = sigma_down - timestep
281 latents = latents + d * dt
282 latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,
283 generator=generator) * sigma_up
284 return SchedulerOutput(prev_sample=latents)
285
286 def step_correct(
287 self,
288 model_output: torch.FloatTensor,
289 sigma_hat: float,
290 sigma_prev: float,
291 sample_hat: torch.FloatTensor,
292 sample_prev: torch.FloatTensor,
293 derivative: torch.FloatTensor,
294 generator: None,
295 return_dict: bool = True,
296 ) -> Union[SchedulerOutput, Tuple]:
297 """
298 Correct the predicted sample based on the output model_output of the network. TODO complete description
299
300 Args:
301 model_output (`torch.FloatTensor`): direct output from learned diffusion model.
302 sigma_hat (`float`): TODO
303 sigma_prev (`float`): TODO
304 sample_hat (`torch.FloatTensor`): TODO
305 sample_prev (`torch.FloatTensor`): TODO
306 derivative (`torch.FloatTensor`): TODO
307 return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
308
309 Returns:
310 prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
311
312 """
313 pred_original_sample = sample_prev + sigma_prev * model_output
314 derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
315 sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
316
317 if not return_dict:
318 return (sample_prev, derivative)
319
320 return SchedulerOutput(prev_sample=sample_prev)
321
322 def add_noise(self, original_samples, noise, timesteps):
323 raise NotImplementedError()