-rw-r--r--   infer.py                                                      | 154
-rw-r--r--   pipelines/stable_diffusion/clip_guided_stable_diffusion.py   | 169
-rw-r--r--   schedulers/scheduling_euler_a.py                              |   6
3 files changed, 113 insertions, 216 deletions
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -1,18 +1,21 @@
 import argparse
 import datetime
 import logging
+import sys
+import shlex
+import cmd
 from pathlib import Path
 from torch import autocast
 import torch
 import json
 from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
-from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
 from slugify import slugify
 from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler


-def parse_args():
+def create_args_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
     )
@@ -22,6 +25,30 @@ def parse_args():
         default=None,
     )
     parser.add_argument(
+        "--scheduler",
+        type=str,
+        choices=["plms", "ddim", "klms", "euler_a"],
+        default="euler_a",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/inference",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+    )
+
+    return parser
+
+
+def create_cmd_parser():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
         "--prompt",
         type=str,
         default=None,
@@ -49,28 +76,17 @@ def parse_args():
     parser.add_argument(
         "--batch_num",
         type=int,
-        default=50,
+        default=1,
     )
     parser.add_argument(
         "--steps",
         type=int,
-        default=120,
-    )
-    parser.add_argument(
-        "--scheduler",
-        type=str,
-        choices=["plms", "ddim", "klms", "euler_a"],
-        default="euler_a",
+        default=70,
     )
     parser.add_argument(
         "--guidance_scale",
         type=int,
-        default=7.5,
-    )
-    parser.add_argument(
-        "--clip_guidance_scale",
-        type=int,
-        default=100,
+        default=7,
     )
     parser.add_argument(
         "--seed",
@@ -78,21 +94,21 @@ def parse_args():
         default=torch.random.seed(),
     )
     parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output/inference",
-    )
-    parser.add_argument(
         "--config",
         type=str,
         default=None,
     )

-    args = parser.parse_args()
+    return parser
+
+
+def run_parser(parser, input=None):
+    args = parser.parse_known_args(input)[0]
+
     if args.config is not None:
         with open(args.config, 'rt') as f:
-            args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)["args"]))
+            args = parser.parse_known_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))[0]

     return args

@@ -104,24 +120,24 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)


-def gen(args, output_dir):
-    tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
-    text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
-    clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
-    vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
-    unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
-    feature_extractor = CLIPFeatureExtractor.from_pretrained(
-        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.bfloat16)
+def create_pipeline(model, scheduler, dtype=torch.bfloat16):
+    print("Loading Stable Diffusion pipeline...")

-    if args.scheduler == "plms":
+    tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
+    text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
+    vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
+    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype)
+
+    if scheduler == "plms":
         scheduler = PNDMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
         )
-    elif args.scheduler == "klms":
+    elif scheduler == "klms":
         scheduler = LMSDiscreteScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
-    elif args.scheduler == "ddim":
+    elif scheduler == "ddim":
         scheduler = DDIMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
         )
@@ -135,13 +151,24 @@ def gen(args, output_dir):
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        clip_model=clip_model,
         scheduler=scheduler,
         feature_extractor=feature_extractor
     )
     pipeline.enable_attention_slicing()
     pipeline.to("cuda")

+    print("Pipeline loaded.")
+
+    return pipeline
+
+
+def generate(output_dir, pipeline, args):
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    save_args(output_dir, args)
+
     with autocast("cuda"):
         for i in range(args.batch_num):
             generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
@@ -152,7 +179,6 @@ def gen(args, output_dir):
                 negative_prompt=args.negative_prompt,
                 num_inference_steps=args.steps,
                 guidance_scale=args.guidance_scale,
-                clip_guidance_scale=args.clip_guidance_scale,
                 generator=generator,
             ).images

@@ -160,18 +186,56 @@ def gen(args, output_dir):
                 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))


-def main():
-    args = parse_args()
+class CmdParse(cmd.Cmd):
+    prompt = 'dream> '
+    commands = []

-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
-    output_dir.mkdir(parents=True, exist_ok=True)
+    def __init__(self, output_dir, pipeline, parser):
+        super().__init__()

-    save_args(output_dir, args)
+        self.output_dir = output_dir
+        self.pipeline = pipeline
+        self.parser = parser
+
+    def default(self, line):
+        line = line.replace("'", "\\'")
+
+        try:
+            elements = shlex.split(line)
+        except ValueError as e:
+            print(str(e))
+
+        if elements[0] == 'q':
+            return True
+
+        try:
+            args = run_parser(self.parser, elements)
+        except SystemExit:
+            self.parser.print_help()
+
+        if len(args.prompt) == 0:
+            print('Try again with a prompt!')
+
+        try:
+            generate(self.output_dir, self.pipeline, args)
+        except KeyboardInterrupt:
+            print('Generation cancelled.')
+
+    def do_exit(self, line):
+        return True
+
+
+def main():
+    logging.basicConfig(stream=sys.stdout, level=logging.WARN)

-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    args_parser = create_args_parser()
+    args = run_parser(args_parser)
+    output_dir = Path(args.output_dir)

-    gen(args, output_dir)
+    pipeline = create_pipeline(args.model, args.scheduler)
+    cmd_parser = create_cmd_parser()
+    cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
+    cmd_prompt.cmdloop()


 if __name__ == "__main__":
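Aside: the CmdParse loop added above is built entirely from the standard-library cmd, shlex, and argparse modules. Below is a minimal, self-contained sketch of that pattern; the names build_parser and Repl and the placeholder print are illustrative only and not part of this commit.

    # Sketch of the cmd/shlex/argparse REPL pattern used by CmdParse above.
    # build_parser and Repl are hypothetical names, not from the commit.
    import argparse
    import cmd
    import shlex


    def build_parser():
        parser = argparse.ArgumentParser()
        parser.add_argument("prompt", type=str, nargs="*")
        parser.add_argument("--steps", type=int, default=70)
        return parser


    class Repl(cmd.Cmd):
        prompt = "dream> "

        def __init__(self, parser):
            super().__init__()
            self.parser = parser

        def default(self, line):
            if line.strip() == "q":
                return True  # a truthy return value ends cmdloop()
            try:
                # parse_known_args keeps unknown tokens instead of failing outright
                args = self.parser.parse_known_args(shlex.split(line))[0]
            except SystemExit:
                # argparse raises SystemExit on bad input; keep the loop alive instead
                self.parser.print_help()
                return
            print(f"would generate {' '.join(args.prompt)!r} with {args.steps} steps")


    if __name__ == "__main__":
        Repl(build_parser()).cmdloop()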
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
index 306d9a9..ddf7ce1 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
@@ -17,53 +17,14 @@ from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class MakeCutouts(nn.Module):
-    def __init__(self, cut_size, cut_power=1.0):
-        super().__init__()
-
-        self.cut_size = cut_size
-        self.cut_power = cut_power
-
-    def forward(self, pixel_values, num_cutouts):
-        sideY, sideX = pixel_values.shape[2:4]
-        max_size = min(sideX, sideY)
-        min_size = min(sideX, sideY, self.cut_size)
-        cutouts = []
-        for _ in range(num_cutouts):
-            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
-            offsetx = torch.randint(0, sideX - size + 1, ())
-            offsety = torch.randint(0, sideY - size + 1, ())
-            cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size]
-            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
-        return torch.cat(cutouts)
-
-
-def spherical_dist_loss(x, y):
-    x = F.normalize(x, dim=-1)
-    y = F.normalize(y, dim=-1)
-    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
-
-
-def set_requires_grad(model, value):
-    for param in model.parameters():
-        param.requires_grad = value
-
-
 class CLIPGuidedStableDiffusion(DiffusionPipeline):
-    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
-    - https://github.com/Jack000/glid-3-xl
-    - https://github.dev/crowsonkb/k-diffusion
-    """
-
     def __init__(
         self,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
-        clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
-        feature_extractor: CLIPFeatureExtractor,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
         **kwargs,
     ):
         super().__init__()
@@ -85,19 +46,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
-            clip_model=clip_model,
             tokenizer=tokenizer,
             unet=unet,
             scheduler=scheduler,
-            feature_extractor=feature_extractor,
         )

-        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
-        self.make_cutouts = MakeCutouts(feature_extractor.size)
-
-        set_requires_grad(self.text_encoder, False)
-        set_requires_grad(self.clip_model, False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -125,87 +78,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)

-    def freeze_vae(self):
-        set_requires_grad(self.vae, False)
-
-    def unfreeze_vae(self):
-        set_requires_grad(self.vae, True)
-
-    def freeze_unet(self):
-        set_requires_grad(self.unet, False)
-
-    def unfreeze_unet(self):
-        set_requires_grad(self.unet, True)
-
-    @torch.enable_grad()
-    def cond_fn(
-        self,
-        latents,
-        timestep,
-        index,
-        text_embeddings,
-        noise_pred_original,
-        text_embeddings_clip,
-        clip_guidance_scale,
-        num_cutouts,
-        use_cutouts=True,
-    ):
-        latents = latents.detach().requires_grad_()
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-            latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
-        else:
-            latent_model_input = latents
-
-        # predict the noise residual
-        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
-
-        if isinstance(self.scheduler, PNDMScheduler):
-            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-            beta_prod_t = 1 - alpha_prod_t
-            # compute predicted original sample from predicted noise also called
-            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
-
-            fac = torch.sqrt(beta_prod_t)
-            sample = pred_original_sample * (fac) + latents * (1 - fac)
-        elif isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            sample = latents - sigma * noise_pred
-        else:
-            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
-
-        sample = 1 / 0.18215 * sample
-        image = self.vae.decode(sample).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        if use_cutouts:
-            image = self.make_cutouts(image, num_cutouts)
-        else:
-            image = transforms.Resize(self.feature_extractor.size)(image)
-        image = self.normalize(image)
-
-        image_embeddings_clip = self.clip_model.get_image_features(image).float()
-        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
-        if use_cutouts:
-            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
-            dists = dists.view([num_cutouts, sample.shape[0], -1])
-            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
-        else:
-            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
-
-        grads = -torch.autograd.grad(loss, latents)[0]
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents.detach() + grads * (sigma**2)
-            noise_pred = noise_pred_original
-        else:
-            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
-        return noise_pred, latents
-
     @torch.no_grad()
     def __call__(
         self,
@@ -216,10 +88,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
-        clip_guidance_scale: Optional[float] = 100,
-        clip_prompt: Optional[Union[str, List[str]]] = None,
-        num_cutouts: Optional[int] = 4,
-        use_cutouts: Optional[bool] = True,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -305,24 +173,10 @@
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
+            print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

-        if clip_guidance_scale > 0:
-            if clip_prompt is not None:
-                clip_text_inputs = self.tokenizer(
-                    clip_prompt,
-                    padding="max_length",
-                    max_length=self.tokenizer.model_max_length,
-                    truncation=True,
-                    return_tensors="pt",
-                )
-                clip_text_input_ids = clip_text_inputs.input_ids
-            else:
-                clip_text_input_ids = text_input_ids
-            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device))
-            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -357,7 +211,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(self.device)
+            latents = latents.to(self.device)

         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -414,23 +268,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-            # perform clip guidance
-            if clip_guidance_scale > 0:
-                text_embeddings_for_guidance = (
-                    text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
-                )
-                noise_pred, latents = self.cond_fn(
-                    latents,
-                    t,
-                    i,
-                    text_embeddings_for_guidance,
-                    noise_pred,
-                    text_embeddings_clip,
-                    clip_guidance_scale,
-                    num_cutouts,
-                    use_cutouts,
-                )
-
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
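Aside: with the CLIP-guidance branch removed, the only per-step post-processing left in the denoising loop is the classifier-free guidance combination kept in the context lines above. A standalone restatement of that one step follows; the helper name cfg_combine is ours, not the pipeline's. With guidance_scale = 1 it collapses to the text-conditioned prediction, matching the "no classifier free guidance" comment retained above.

    import torch


    def cfg_combine(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # noise_pred stacks the unconditional and text-conditioned predictions along dim 0,
        # as produced by running the UNet on the concatenated [uncond, cond] batch
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)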
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 57a56de..29ebd07 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -216,7 +216,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):

         self.num_inference_steps = num_inference_steps
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
+        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps - 1).to(device=device)
         self.timesteps = self.sigmas

     def add_noise_to_input(
@@ -272,11 +272,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         """
         latents = sample
         sigma_down, sigma_up = get_ancestral_step(timestep, timestep_prev)
-
-        # if callback is not None:
-        #     callback({'x': latents, 'i': i, 'sigma': timestep, 'sigma_hat': timestep, 'denoised': model_output})
         d = to_d(latents, timestep, model_output)
-        # Euler method
         dt = sigma_down - timestep
         latents = latents + d * dt
         latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device,