-rw-r--r--   infer.py                                                    | 154
-rw-r--r--   pipelines/stable_diffusion/clip_guided_stable_diffusion.py | 169
-rw-r--r--   schedulers/scheduling_euler_a.py                            |   6
3 files changed, 113 insertions(+), 216 deletions(-)

diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -1,18 +1,21 @@
 import argparse
 import datetime
 import logging
+import sys
+import shlex
+import cmd
 from pathlib import Path
 from torch import autocast
 import torch
 import json
 from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
-from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
 from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 
-def parse_args():
+def create_args_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
     )
@@ -22,6 +25,30 @@ def parse_args():
         default=None,
     )
     parser.add_argument(
+        "--scheduler",
+        type=str,
+        choices=["plms", "ddim", "klms", "euler_a"],
+        default="euler_a",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/inference",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+    )
+
+    return parser
+
+
+def create_cmd_parser():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
         "--prompt",
         type=str,
         default=None,
@@ -49,28 +76,17 @@ def parse_args():
     parser.add_argument(
         "--batch_num",
         type=int,
-        default=50,
+        default=1,
     )
     parser.add_argument(
         "--steps",
         type=int,
-        default=120,
-    )
-    parser.add_argument(
-        "--scheduler",
-        type=str,
-        choices=["plms", "ddim", "klms", "euler_a"],
-        default="euler_a",
+        default=70,
     )
     parser.add_argument(
         "--guidance_scale",
         type=int,
-        default=7.5,
-    )
-    parser.add_argument(
-        "--clip_guidance_scale",
-        type=int,
-        default=100,
+        default=7,
     )
     parser.add_argument(
         "--seed",
@@ -78,21 +94,21 @@ def parse_args():
         default=torch.random.seed(),
     )
     parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output/inference",
-    )
-    parser.add_argument(
         "--config",
         type=str,
         default=None,
     )
 
-    args = parser.parse_args()
+    return parser
+
+
+def run_parser(parser, input=None):
+    args = parser.parse_known_args(input)[0]
+
     if args.config is not None:
         with open(args.config, 'rt') as f:
-            args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)["args"]))
+            args = parser.parse_known_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))[0]
 
     return args
 
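Note: `run_parser` accepts a `--config` JSON file whose top-level `"args"` object supplies values for the arguments defined above. A minimal hypothetical example follows (the model path is a placeholder, and the key names simply mirror the argparse destinations in this patch):

# Hypothetical config, consumed via: python infer.py --config inference.json
import json

config = {
    "args": {
        "model": "models/my-finetuned-model",   # placeholder path
        "scheduler": "euler_a",
        "output_dir": "output/inference",
    }
}

with open("inference.json", "w") as f:
    json.dump(config, f, indent=4)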
@@ -104,24 +120,24 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def gen(args, output_dir):
-    tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
-    text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
-    clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
-    vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
-    unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
-    feature_extractor = CLIPFeatureExtractor.from_pretrained(
-        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
+def create_pipeline(model, scheduler, dtype=torch.bfloat16):
+    print("Loading Stable Diffusion pipeline...")
 
-    if args.scheduler == "plms":
+    tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
+    text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
+    vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
+    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype)
+
+    if scheduler == "plms":
         scheduler = PNDMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
         )
-    elif args.scheduler == "klms":
+    elif scheduler == "klms":
         scheduler = LMSDiscreteScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
-    elif args.scheduler == "ddim":
+    elif scheduler == "ddim":
         scheduler = DDIMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
         )
@@ -135,13 +151,24 @@ def gen(args, output_dir):
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        clip_model=clip_model,
         scheduler=scheduler,
         feature_extractor=feature_extractor
     )
     pipeline.enable_attention_slicing()
     pipeline.to("cuda")
 
+    print("Pipeline loaded.")
+
+    return pipeline
+
+
+def generate(output_dir, pipeline, args):
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    save_args(output_dir, args)
+
     with autocast("cuda"):
         for i in range(args.batch_num):
             generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
@@ -152,7 +179,6 @@ def gen(args, output_dir):
                 negative_prompt=args.negative_prompt,
                 num_inference_steps=args.steps,
                 guidance_scale=args.guidance_scale,
-                clip_guidance_scale=args.clip_guidance_scale,
                 generator=generator,
             ).images
 
@@ -160,18 +186,56 @@ def gen(args, output_dir):
                 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
 
 
-def main():
-    args = parse_args()
+class CmdParse(cmd.Cmd):
+    prompt = 'dream> '
+    commands = []
 
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
-    output_dir.mkdir(parents=True, exist_ok=True)
+    def __init__(self, output_dir, pipeline, parser):
+        super().__init__()
 
-    save_args(output_dir, args)
+        self.output_dir = output_dir
+        self.pipeline = pipeline
+        self.parser = parser
+
+    def default(self, line):
+        line = line.replace("'", "\\'")
+
+        try:
+            elements = shlex.split(line)
+        except ValueError as e:
+            print(str(e))
+
+        if elements[0] == 'q':
+            return True
+
+        try:
+            args = run_parser(self.parser, elements)
+        except SystemExit:
+            self.parser.print_help()
+
+        if len(args.prompt) == 0:
+            print('Try again with a prompt!')
+
+        try:
+            generate(self.output_dir, self.pipeline, args)
+        except KeyboardInterrupt:
+            print('Generation cancelled.')
+
+    def do_exit(self, line):
+        return True
+
+
+def main():
+    logging.basicConfig(stream=sys.stdout, level=logging.WARN)
 
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    args_parser = create_args_parser()
+    args = run_parser(args_parser)
+    output_dir = Path(args.output_dir)
 
-    gen(args, output_dir)
+    pipeline = create_pipeline(args.model, args.scheduler)
+    cmd_parser = create_cmd_parser()
+    cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
+    cmd_prompt.cmdloop()
 
 
 if __name__ == "__main__":
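Note: after this change infer.py loads the pipeline once and then reads prompt commands at a dream> REPL (type q, or use the exit command, to quit). For illustration, a minimal sketch, not part of the patch, that drives the refactored helpers directly instead of going through the interactive loop; the --model value and prompt are placeholders, and the --model flag name is inferred from the args.model usage above:

from pathlib import Path

# Build the startup arguments without a real command line.
args_parser = create_args_parser()
args = run_parser(args_parser, ["--model", "models/stable-diffusion", "--scheduler", "euler_a"])

# Load models and the chosen scheduler once.
pipeline = create_pipeline(args.model, args.scheduler)

# Parse one prompt "command" and generate images for it.
cmd_parser = create_cmd_parser()
prompt_args = run_parser(cmd_parser, ["--prompt", "a watercolor painting of a lighthouse"])
generate(Path(args.output_dir), pipeline, prompt_args)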
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
index 306d9a9..ddf7ce1 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
@@ -17,53 +17,14 @@ from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class MakeCutouts(nn.Module):
-    def __init__(self, cut_size, cut_power=1.0):
-        super().__init__()
-
-        self.cut_size = cut_size
-        self.cut_power = cut_power
-
-    def forward(self, pixel_values, num_cutouts):
-        sideY, sideX = pixel_values.shape[2:4]
-        max_size = min(sideX, sideY)
-        min_size = min(sideX, sideY, self.cut_size)
-        cutouts = []
-        for _ in range(num_cutouts):
-            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
-            offsetx = torch.randint(0, sideX - size + 1, ())
-            offsety = torch.randint(0, sideY - size + 1, ())
-            cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size]
-            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
-        return torch.cat(cutouts)
-
-
-def spherical_dist_loss(x, y):
-    x = F.normalize(x, dim=-1)
-    y = F.normalize(y, dim=-1)
-    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
-
-
-def set_requires_grad(model, value):
-    for param in model.parameters():
-        param.requires_grad = value
-
-
 class CLIPGuidedStableDiffusion(DiffusionPipeline):
-    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
-    - https://github.com/Jack000/glid-3-xl
-    - https://github.dev/crowsonkb/k-diffusion
-    """
-
     def __init__(
         self,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
-        clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
-        feature_extractor: CLIPFeatureExtractor,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
         **kwargs,
     ):
         super().__init__()
@@ -85,19 +46,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
-            clip_model=clip_model,
             tokenizer=tokenizer,
             unet=unet,
             scheduler=scheduler,
-            feature_extractor=feature_extractor,
         )
 
-        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
-        self.make_cutouts = MakeCutouts(feature_extractor.size)
-
-        set_requires_grad(self.text_encoder, False)
-        set_requires_grad(self.clip_model, False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -125,87 +78,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
-    def freeze_vae(self):
-        set_requires_grad(self.vae, False)
-
-    def unfreeze_vae(self):
-        set_requires_grad(self.vae, True)
-
-    def freeze_unet(self):
-        set_requires_grad(self.unet, False)
-
-    def unfreeze_unet(self):
-        set_requires_grad(self.unet, True)
-
-    @torch.enable_grad()
-    def cond_fn(
-        self,
-        latents,
-        timestep,
-        index,
-        text_embeddings,
-        noise_pred_original,
-        text_embeddings_clip,
-        clip_guidance_scale,
-        num_cutouts,
-        use_cutouts=True,
-    ):
-        latents = latents.detach().requires_grad_()
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-            latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
-        else:
-            latent_model_input = latents
-
-        # predict the noise residual
-        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
-
-        if isinstance(self.scheduler, PNDMScheduler):
-            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-            beta_prod_t = 1 - alpha_prod_t
-            # compute predicted original sample from predicted noise also called
-            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
-
-            fac = torch.sqrt(beta_prod_t)
-            sample = pred_original_sample * (fac) + latents * (1 - fac)
-        elif isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            sample = latents - sigma * noise_pred
-        else:
-            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
-
-        sample = 1 / 0.18215 * sample
-        image = self.vae.decode(sample).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        if use_cutouts:
-            image = self.make_cutouts(image, num_cutouts)
-        else:
-            image = transforms.Resize(self.feature_extractor.size)(image)
-        image = self.normalize(image)
-
-        image_embeddings_clip = self.clip_model.get_image_features(image).float()
-        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
-        if use_cutouts:
-            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
-            dists = dists.view([num_cutouts, sample.shape[0], -1])
-            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
-        else:
-            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
-
-        grads = -torch.autograd.grad(loss, latents)[0]
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents.detach() + grads * (sigma**2)
-            noise_pred = noise_pred_original
-        else:
-            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
-        return noise_pred, latents
-
     @torch.no_grad()
     def __call__(
         self,
@@ -216,10 +88,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
-        clip_guidance_scale: Optional[float] = 100,
-        clip_prompt: Optional[Union[str, List[str]]] = None,
-        num_cutouts: Optional[int] = 4,
-        use_cutouts: Optional[bool] = True,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -305,24 +173,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
+            print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
-        if clip_guidance_scale > 0:
-            if clip_prompt is not None:
-                clip_text_inputs = self.tokenizer(
-                    clip_prompt,
-                    padding="max_length",
-                    max_length=self.tokenizer.model_max_length,
-                    truncation=True,
-                    return_tensors="pt",
-                )
-                clip_text_input_ids = clip_text_inputs.input_ids
-            else:
-                clip_text_input_ids = text_input_ids
-            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device))
-            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -357,7 +211,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+            latents = latents.to(self.device)
 
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -414,23 +268,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-            # perform clip guidance
-            if clip_guidance_scale > 0:
-                text_embeddings_for_guidance = (
-                    text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
-                )
-                noise_pred, latents = self.cond_fn(
-                    latents,
-                    t,
-                    i,
-                    text_embeddings_for_guidance,
-                    noise_pred,
-                    text_embeddings_clip,
-                    clip_guidance_scale,
-                    num_cutouts,
-                    use_cutouts,
-                )
-
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
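Note: with the CLIP-guidance branch removed, the only guidance left in this pipeline is classifier-free guidance, i.e. the retained line noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond). A stand-alone sketch of that combination step, with made-up tensor shapes:

import torch

guidance_scale = 7.5
# Unconditional and text-conditioned predictions stacked along the batch dimension,
# as produced by running the UNet once on the duplicated latent batch.
noise_pred = torch.randn(2, 4, 64, 64)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# Push the prediction away from the unconditional estimate, toward the text-conditioned one.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)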
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 57a56de..29ebd07 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -216,7 +216,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
 
         self.num_inference_steps = num_inference_steps
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
+        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device)
         self.timesteps = self.sigmas
 
     def add_noise_to_input(
@@ -272,11 +272,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         """
         latents = sample
         sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
-
-        # if callback is not None:
-        #     callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output})
         d = to_d(latents, timestep, model_output)
-        # Euler method
         dt = sigma_down - timestep
         latents = latents + d * dt
         latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,
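Note: for context on the step logic kept above (d = to_d(...), dt = sigma_down - timestep, an Euler update, then fresh noise scaled for the ancestral step), here is a self-contained sketch of one Euler-ancestral update in the style of k-diffusion; to_d and get_ancestral_step are simplified stand-ins written for this sketch, not this module's own definitions:

import torch

def to_d(x, sigma, denoised):
    # Convert a denoised prediction into the ODE derivative used by k-diffusion samplers.
    return (x - denoised) / sigma

def get_ancestral_step(sigma_from, sigma_to):
    # Split the sigma transition into a deterministic part (sigma_down) and injected noise (sigma_up).
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up

def euler_ancestral_step(x, sigma, sigma_next, denoised, generator=None):
    sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next)
    d = to_d(x, sigma, denoised)
    dt = sigma_down - sigma
    x = x + d * dt                                                 # deterministic Euler step
    x = x + torch.randn(x.shape, generator=generator) * sigma_up   # ancestral noise injection
    return x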
