path: root/pipelines/stable_diffusion
author    Volpeon <git@volpeon.ink> 2022-10-01 11:40:14 +0200
committer Volpeon <git@volpeon.ink> 2022-10-01 11:40:14 +0200
commit    5b3eb3b24c2ed33911a7c50b5b1e0f729b86c688 (patch)
tree      a3461a4f1a04fba52ec8fde8b7b07095c7422d85 /pipelines/stable_diffusion
parent    Added custom SD pipeline + euler_a scheduler (diff)
Made inference script interactive
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r--  pipelines/stable_diffusion/clip_guided_stable_diffusion.py | 169
1 file changed, 3 insertions(+), 166 deletions(-)
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
index 306d9a9..ddf7ce1 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
@@ -17,53 +17,14 @@ from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class MakeCutouts(nn.Module):
-    def __init__(self, cut_size, cut_power=1.0):
-        super().__init__()
-
-        self.cut_size = cut_size
-        self.cut_power = cut_power
-
-    def forward(self, pixel_values, num_cutouts):
-        sideY, sideX = pixel_values.shape[2:4]
-        max_size = min(sideX, sideY)
-        min_size = min(sideX, sideY, self.cut_size)
-        cutouts = []
-        for _ in range(num_cutouts):
-            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
-            offsetx = torch.randint(0, sideX - size + 1, ())
-            offsety = torch.randint(0, sideY - size + 1, ())
-            cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size]
-            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
-        return torch.cat(cutouts)
-
-
-def spherical_dist_loss(x, y):
-    x = F.normalize(x, dim=-1)
-    y = F.normalize(y, dim=-1)
-    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
-
-
-def set_requires_grad(model, value):
-    for param in model.parameters():
-        param.requires_grad = value
-
-
 class CLIPGuidedStableDiffusion(DiffusionPipeline):
-    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
-    - https://github.com/Jack000/glid-3-xl
-    - https://github.dev/crowsonkb/k-diffusion
-    """
-
     def __init__(
         self,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
-        clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
-        feature_extractor: CLIPFeatureExtractor,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAScheduler],
         **kwargs,
     ):
         super().__init__()
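
(Note: after this hunk the pipeline no longer owns a CLIP model or feature extractor, and the scheduler union gains the repo's EulerAScheduler. A minimal construction sketch of the slimmed-down pipeline follows; the model id, subfolder layout, and scheduler parameters are assumptions for illustration, not taken from this diff.)

from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from pipelines.stable_diffusion.clip_guided_stable_diffusion import CLIPGuidedStableDiffusion

# Assumed model id and subfolder layout (standard SD v1 checkpoints).
model_id = "CompVis/stable-diffusion-v1-4"
pipeline = CLIPGuidedStableDiffusion(
    vae=AutoencoderKL.from_pretrained(model_id, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer"),
    unet=UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"),
    # Standard SD v1 beta schedule; any scheduler in the Union works here.
    scheduler=PNDMScheduler(beta_start=0.00085, beta_end=0.012,
                            beta_schedule="scaled_linear", skip_prk_steps=True),
).to("cuda")  # requires a CUDA device; keep on CPU otherwise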
@@ -85,19 +46,11 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
-            clip_model=clip_model,
             tokenizer=tokenizer,
             unet=unet,
             scheduler=scheduler,
-            feature_extractor=feature_extractor,
         )
 
-        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
-        self.make_cutouts = MakeCutouts(feature_extractor.size)
-
-        set_requires_grad(self.text_encoder, False)
-        set_requires_grad(self.clip_model, False)
-
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -125,87 +78,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
-    def freeze_vae(self):
-        set_requires_grad(self.vae, False)
-
-    def unfreeze_vae(self):
-        set_requires_grad(self.vae, True)
-
-    def freeze_unet(self):
-        set_requires_grad(self.unet, False)
-
-    def unfreeze_unet(self):
-        set_requires_grad(self.unet, True)
-
-    @torch.enable_grad()
-    def cond_fn(
-        self,
-        latents,
-        timestep,
-        index,
-        text_embeddings,
-        noise_pred_original,
-        text_embeddings_clip,
-        clip_guidance_scale,
-        num_cutouts,
-        use_cutouts=True,
-    ):
-        latents = latents.detach().requires_grad_()
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-            latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
-        else:
-            latent_model_input = latents
-
-        # predict the noise residual
-        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
-
-        if isinstance(self.scheduler, PNDMScheduler):
-            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-            beta_prod_t = 1 - alpha_prod_t
-            # compute predicted original sample from predicted noise also called
-            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
-
-            fac = torch.sqrt(beta_prod_t)
-            sample = pred_original_sample * (fac) + latents * (1 - fac)
-        elif isinstance(self.scheduler, LMSDiscreteScheduler):
-            sigma = self.scheduler.sigmas[index]
-            sample = latents - sigma * noise_pred
-        else:
-            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
-
-        sample = 1 / 0.18215 * sample
-        image = self.vae.decode(sample).sample
-        image = (image / 2 + 0.5).clamp(0, 1)
-
-        if use_cutouts:
-            image = self.make_cutouts(image, num_cutouts)
-        else:
-            image = transforms.Resize(self.feature_extractor.size)(image)
-        image = self.normalize(image)
-
-        image_embeddings_clip = self.clip_model.get_image_features(image).float()
-        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
-        if use_cutouts:
-            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
-            dists = dists.view([num_cutouts, sample.shape[0], -1])
-            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
-        else:
-            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
-
-        grads = -torch.autograd.grad(loss, latents)[0]
-
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents.detach() + grads * (sigma**2)
-            noise_pred = noise_pred_original
-        else:
-            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
-        return noise_pred, latents
-
     @torch.no_grad()
     def __call__(
         self,
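
(Note: the deleted cond_fn was the core of CLIP guidance: decode a preview image from the latents, embed it with CLIP, and push the noise prediction along the gradient of spherical_dist_loss. That loss is the squared geodesic distance on the unit sphere, scaled by one half. A self-contained sanity check, for illustration only and not part of this repo:)

import torch
import torch.nn.functional as F

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

# For unit vectors the chord length is 2*sin(theta/2), so the loss equals
# theta**2 / 2, with theta the angle between the normalized embeddings.
x, y = torch.randn(4, 512), torch.randn(4, 512)
theta = torch.arccos(F.cosine_similarity(x, y, dim=-1).clamp(-1.0, 1.0))
assert torch.allclose(spherical_dist_loss(x, y), theta ** 2 / 2, atol=1e-5)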
@@ -216,10 +88,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
-        clip_guidance_scale: Optional[float] = 100,
-        clip_prompt: Optional[Union[str, List[str]]] = None,
-        num_cutouts: Optional[int] = 4,
-        use_cutouts: Optional[bool] = True,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
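
(Note: with these keyword arguments removed, callers that still pass them should now fail with a TypeError instead of being silently ignored, assuming the remaining __call__ signature has no **kwargs catch-all; the full signature is not shown in this hunk.)

pipeline("a photo of an astronaut", guidance_scale=7.5)       # still valid
pipeline("a photo of an astronaut", clip_guidance_scale=100)  # now a TypeError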
@@ -305,24 +173,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
+            print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
 
-        if clip_guidance_scale > 0:
-            if clip_prompt is not None:
-                clip_text_inputs = self.tokenizer(
-                    clip_prompt,
-                    padding="max_length",
-                    max_length=self.tokenizer.model_max_length,
-                    truncation=True,
-                    return_tensors="pt",
-                )
-                clip_text_input_ids = clip_text_inputs.input_ids
-            else:
-                clip_text_input_ids = text_input_ids
-            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids.to(self.device))
-            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
-
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
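
(Note: for reference, removed_text is typically the decoded tail of the over-long token ids. The sketch below is reconstructed from the usual diffusers pattern of this era, since the defining line sits outside the hunk:)

# Assumed derivation of `removed_text`; the defining line is outside this hunk.
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])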
@@ -357,7 +211,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+        latents = latents.to(self.device)
 
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
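
(Note: dedenting latents.to(self.device) means user-supplied latents are always moved to the execution device, so they can be generated on the CPU for reproducible seeds. A hedged usage sketch; the latent shape assumes 512x512 output with the VAE's 8x spatial downscale:)

import torch

generator = torch.Generator(device="cpu").manual_seed(42)
# (batch, unet.in_channels, height // 8, width // 8) for a 512x512 image
latents = torch.randn((1, 4, 64, 64), generator=generator)
result = pipeline("a photo of an astronaut", latents=latents)  # `pipeline` from the sketch above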
@@ -414,23 +268,6 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-            # perform clip guidance
-            if clip_guidance_scale > 0:
-                text_embeddings_for_guidance = (
-                    text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
-                )
-                noise_pred, latents = self.cond_fn(
-                    latents,
-                    t,
-                    i,
-                    text_embeddings_for_guidance,
-                    noise_pred,
-                    text_embeddings_clip,
-                    clip_guidance_scale,
-                    num_cutouts,
-                    use_cutouts,
-                )
-
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample