diff options
Diffstat (limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 52 |
1 files changed, 49 insertions, 3 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 8927a78..1a84c8d 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -4,13 +4,14 @@ from typing import List, Optional, Union | |||
4 | 4 | ||
5 | import numpy as np | 5 | import numpy as np |
6 | import torch | 6 | import torch |
7 | import torch.optim as optim | ||
7 | import PIL | 8 | import PIL |
8 | 9 | ||
9 | from diffusers.configuration_utils import FrozenDict | 10 | from diffusers.configuration_utils import FrozenDict |
10 | from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel | 11 | from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel |
11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 12 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
12 | from diffusers.utils import logging | 13 | from diffusers.utils import logging |
13 | from transformers import CLIPTextModel, CLIPTokenizer | 14 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel |
14 | from schedulers.scheduling_euler_a import EulerAScheduler | 15 | from schedulers.scheduling_euler_a import EulerAScheduler |
15 | 16 | ||
16 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 17 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
@@ -50,6 +51,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
50 | new_config["steps_offset"] = 1 | 51 | new_config["steps_offset"] = 1 |
51 | scheduler._internal_dict = FrozenDict(new_config) | 52 | scheduler._internal_dict = FrozenDict(new_config) |
52 | 53 | ||
54 | self.aesthetic_gradient_embeddings = {} | ||
55 | self.aesthetic_gradient_lr = 1e-4 | ||
56 | self.aesthetic_gradient_iters = 10 | ||
57 | |||
53 | self.register_modules( | 58 | self.register_modules( |
54 | vae=vae, | 59 | vae=vae, |
55 | text_encoder=text_encoder, | 60 | text_encoder=text_encoder, |
@@ -58,6 +63,47 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
58 | scheduler=scheduler, | 63 | scheduler=scheduler, |
59 | ) | 64 | ) |
60 | 65 | ||
66 | def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor): | ||
67 | self.aesthetic_gradient_embeddings[keyword] = tensor | ||
68 | |||
69 | def get_text_embeddings(self, prompt, text_input_ids): | ||
70 | prompt = " ".join(prompt) | ||
71 | |||
72 | embeddings = [ | ||
73 | embedding | ||
74 | for key, embedding in self.aesthetic_gradient_embeddings.items() | ||
75 | if key in prompt | ||
76 | ] | ||
77 | |||
78 | if len(embeddings) != 0: | ||
79 | with torch.enable_grad(): | ||
80 | full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") | ||
81 | full_clip_model.to(self.device) | ||
82 | full_clip_model.text_model.train() | ||
83 | |||
84 | optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr) | ||
85 | |||
86 | for embs in embeddings: | ||
87 | embs = embs.clone().detach().to(self.device) | ||
88 | embs /= embs.norm(dim=-1, keepdim=True) | ||
89 | |||
90 | for i in range(self.aesthetic_gradient_iters): | ||
91 | text_embs = full_clip_model.get_text_features(text_input_ids) | ||
92 | text_embs /= text_embs.norm(dim=-1, keepdim=True) | ||
93 | sim = text_embs @ embs.T | ||
94 | loss = -sim | ||
95 | loss = loss.mean() | ||
96 | |||
97 | loss.backward() | ||
98 | optimizer.step() | ||
99 | optimizer.zero_grad() | ||
100 | |||
101 | full_clip_model.text_model.eval() | ||
102 | |||
103 | return full_clip_model.text_model(text_input_ids)[0] | ||
104 | else: | ||
105 | return self.text_encoder(text_input_ids)[0] | ||
106 | |||
61 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): | 107 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): |
62 | r""" | 108 | r""" |
63 | Enable sliced attention computation. | 109 | Enable sliced attention computation. |
@@ -195,7 +241,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
195 | ) | 241 | ) |
196 | print(f"Too many tokens: {removed_text}") | 242 | print(f"Too many tokens: {removed_text}") |
197 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] | 243 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] |
198 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] | 244 | text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device)) |
199 | 245 | ||
200 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) | 246 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) |
201 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` | 247 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` |
@@ -207,7 +253,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
207 | uncond_input = self.tokenizer( | 253 | uncond_input = self.tokenizer( |
208 | negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" | 254 | negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" |
209 | ) | 255 | ) |
210 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] | 256 | uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device)) |
211 | 257 | ||
212 | # For classifier free guidance, we need to do two forward passes. | 258 | # For classifier free guidance, we need to do two forward passes. |
213 | # Here we concatenate the unconditional and text embeddings into a single batch | 259 | # Here we concatenate the unconditional and text embeddings into a single batch |