author     Volpeon <git@volpeon.ink>    2022-10-01 11:40:14 +0200
committer  Volpeon <git@volpeon.ink>    2022-10-01 11:40:14 +0200
commit     5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688
tree       a3461a4f1a04fba52ec8fde8b7b07095c7422d85 /pipelines/stable_diffusion
parent     Added custom SD pipeline + euler_a scheduler
Made inference script interactive
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--  pipelines/stable_diffusion/clip_guided_stable_diffusion.py  169
1 file changed, 3 insertions, 166 deletions
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
index 306d9a9..ddf7ce1 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
@@ -17,53 +17,14 @@ from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class MakeCutouts(nn.Module):
-    def __init__(self, cut_size, cut_power=1.0):
-        super().__init__()
-
-        self.cut_size = cut_size
-        self.cut_power = cut_power
-
-    def forward(self, pixel_values, num_cutouts):
-        sideY, sideX = pixel_values.shape[2:4]
-        max_size = min(sideX, sideY)
-        min_size = min(sideX, sideY, self.cut_size)
-        cutouts = []
-        for _ in range(num_cutouts):
-            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
-            offsetx = torch.randint(0, sideX - size + 1, ())
-            offsety = torch.randint(0, sideY - size + 1, ())
-            cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size]
-            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
-        return torch.cat(cutouts)
-
-
-def spherical_dist_loss(x, y):
-    x = F.normalize(x, dim=-1)
-    y = F.normalize(y, dim=-1)
-    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
-
-
-def set_requires_grad(model, value):
-    for param in model.parameters():
-        param.requires_grad = value
-
-
 class CLIPGuidedStableDiffusion(DiffusionPipeline):
-    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
-    - https://github.com/Jack000/glid-3-xl
-    - https://github.dev/crowsonkb/k-diffusion
-    """
-
     def __init__(
         self,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
-        clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
-        feature_extractor: CLIPFeatureExtractor,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
         **kwargs,
     ):
         super().__init__()
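For reference, the helpers removed above worked as follows: `MakeCutouts` sampled `num_cutouts` random square crops of the decoded image (crop sizes biased by `cut_power`) and average-pooled each one to `cut_size`, while `spherical_dist_loss` returned the squared great-circle distance between L2-normalized CLIP embeddings:

$$
d(x, y) = 2 \arcsin\!\left(\tfrac{1}{2}\,\Big\lVert \tfrac{x}{\lVert x \rVert_2} - \tfrac{y}{\lVert y \rVert_2} \Big\rVert_2\right)^{2}
$$

Both were used only by the CLIP-guidance path that the rest of this commit removes.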
@@ -85,19 +46,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
-            clip_model=clip_model,
             tokenizer=tokenizer,
             unet=unet,
             scheduler=scheduler,
-            feature_extractor=feature_extractor,
         )

-        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
-        self.make_cutouts = MakeCutouts(feature_extractor.size)
-
-        set_requires_grad(self.text_encoder, False)
-        set_requires_grad(self.clip_model, False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -125,87 +78,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)

-    def freeze_vae(self):
-        set_requires_grad(self.vae, False)
-
-    def unfreeze_vae(self):
-        set_requires_grad(self.vae, True)
-
-    def freeze_unet(self):
-        set_requires_grad(self.unet, False)
-
-    def unfreeze_unet(self):
-        set_requires_grad(self.unet, True)
-
-    @torch.enable_grad()
-    def cond_fn(
-        self,
-        latents,
-        timestep,
-        index,
-        text_embeddings,
-        noise_pred_original,
-        text_embeddings_clip,
-        clip_guidance_scale,
-        num_cutouts,
-        use_cutouts=True,
-    ):
-        latents = latents.detach().requires_grad_()
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-            latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
-        else:
-            latent_model_input = latents
-
-        # predict the noise residual
-        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
-
-        if isinstance(self.scheduler, PNDMScheduler):
-            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-            beta_prod_t = 1 - alpha_prod_t
-            # compute predicted original sample from predicted noise also called
-            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
-
-            fac = torch.sqrt(beta_prod_t)
-            sample = pred_original_sample * (fac) + latents * (1 - fac)
-        elif isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            sample = latents - sigma * noise_pred
-        else:
-            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
-
-        sample = 1 / 0.18215 * sample
-        image = self.vae.decode(sample).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        if use_cutouts:
-            image = self.make_cutouts(image, num_cutouts)
-        else:
-            image = transforms.Resize(self.feature_extractor.size)(image)
-        image = self.normalize(image)
-
-        image_embeddings_clip = self.clip_model.get_image_features(image).float()
-        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
-        if use_cutouts:
-            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
-            dists = dists.view([num_cutouts, sample.shape[0], -1])
-            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
-        else:
-            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
-
-        grads = -torch.autograd.grad(loss, latents)[0]
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents.detach() + grads * (sigma**2)
-            noise_pred = noise_pred_original
-        else:
-            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
-        return noise_pred, latents
-
     @torch.no_grad()
     def __call__(
         self,
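The removed `cond_fn` was the core of CLIP guidance: with gradients enabled it re-predicted the noise, reconstructed the "predicted x_0" of DDIM formula (12) (https://arxiv.org/pdf/2010.02502.pdf), decoded it through the VAE, scored cutouts of the decoded image against the CLIP text embedding via `spherical_dist_loss`, and pushed the sampling trajectory along the loss gradient. In the notation of the removed code ($\bar\alpha_t$ = `alpha_prod_t`, $\epsilon$ = `noise_pred`, $g$ = `grads` $= -\nabla_{x_t}\mathcal{L}_{\mathrm{CLIP}}$):

$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon}{\sqrt{\bar\alpha_t}}, \qquad
\epsilon' = \epsilon - \sqrt{1-\bar\alpha_t}\;g \;\;\text{(PNDM branch)}, \qquad
x_t' = x_t + \sigma_t^{2}\, g \;\;\text{(K-LMS branch)} .
$$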
@@ -216,10 +88,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
-        clip_guidance_scale: Optional[float] = 100,
-        clip_prompt: Optional[Union[str, List[str]]] = None,
-        num_cutouts: Optional[int] = 4,
-        use_cutouts: Optional[bool] = True,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -305,24 +173,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
+            print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

-        if clip_guidance_scale > 0:
-            if clip_prompt is not None:
-                clip_text_inputs = self.tokenizer(
-                    clip_prompt,
-                    padding="max_length",
-                    max_length=self.tokenizer.model_max_length,
-                    truncation=True,
-                    return_tensors="pt",
-                )
-                clip_text_input_ids = clip_text_inputs.input_ids
-            else:
-                clip_text_input_ids = text_input_ids
-            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device))
-            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
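The single added line prints the truncated tail of the prompt to the console, in keeping with the commit's goal of an interactive inference script. A minimal, self-contained sketch of that truncation path, assuming the standard CLIP tokenizer from `transformers` (the checkpoint name and example prompt below are placeholders, not taken from this repository):

```python
from transformers import CLIPTokenizer

# Placeholder checkpoint; the real pipeline receives its tokenizer via __init__.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompt = "a painting of " + ", ".join(["a lighthouse"] * 60)  # deliberately longer than 77 tokens

text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > tokenizer.model_max_length:
    # Decode only the tail that is about to be cut off, as the pipeline does.
    removed_text = tokenizer.batch_decode(text_input_ids[:, tokenizer.model_max_length:])
    print(f"Too many tokens: {removed_text}")
    text_input_ids = text_input_ids[:, : tokenizer.model_max_length]
```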
@@ -357,7 +211,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+        latents = latents.to(self.device)

         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -414,23 +268,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-            # perform clip guidance
-            if clip_guidance_scale > 0:
-                text_embeddings_for_guidance = (
-                    text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
-                )
-                noise_pred, latents = self.cond_fn(
-                    latents,
-                    t,
-                    i,
-                    text_embeddings_for_guidance,
-                    noise_pred,
-                    text_embeddings_clip,
-                    clip_guidance_scale,
-                    num_cutouts,
-                    use_cutouts,
-                )
-
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
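For reference, the classifier-free guidance combination kept in the loop above implements equation (2) of the Imagen paper cited earlier in the file:

$$
\tilde{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \varnothing) + w\,\big(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)\big),
$$

where $w$ is `guidance_scale`, $c$ the prompt embedding, and $\varnothing$ the unconditional embedding; $w = 1$ corresponds to no classifier-free guidance, matching the comment above. With the CLIP-guidance branch removed, this is the only guidance applied before the scheduler step.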