summaryrefslogtreecommitdiffstats
path: root/pipelines/stable_diffusion/vlpn_stable_diffusion.py
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py52
1 files changed, 49 insertions, 3 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 8927a78..1a84c8d 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -4,13 +4,14 @@ from typing import List, Optional, Union
4 4
5import numpy as np 5import numpy as np
6import torch 6import torch
7import torch.optim as optim
7import PIL 8import PIL
8 9
9from diffusers.configuration_utils import FrozenDict 10from diffusers.configuration_utils import FrozenDict
10from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel 11from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
11from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 12from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
12from diffusers.utils import logging 13from diffusers.utils import logging
13from transformers import CLIPTextModel, CLIPTokenizer 14from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel
14from schedulers.scheduling_euler_a import EulerAScheduler 15from schedulers.scheduling_euler_a import EulerAScheduler
15 16
16logger = logging.get_logger(__name__) # pylint: disable=invalid-name 17logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -50,6 +51,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
50 new_config["steps_offset"] = 1 51 new_config["steps_offset"] = 1
51 scheduler._internal_dict = FrozenDict(new_config) 52 scheduler._internal_dict = FrozenDict(new_config)
52 53
54 self.aesthetic_gradient_embeddings = {}
55 self.aesthetic_gradient_lr = 1e-4
56 self.aesthetic_gradient_iters = 10
57
53 self.register_modules( 58 self.register_modules(
54 vae=vae, 59 vae=vae,
55 text_encoder=text_encoder, 60 text_encoder=text_encoder,
@@ -58,6 +63,47 @@ class VlpnStableDiffusion(DiffusionPipeline):
58 scheduler=scheduler, 63 scheduler=scheduler,
59 ) 64 )
60 65
66 def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor):
67 self.aesthetic_gradient_embeddings[keyword] = tensor
68
69 def get_text_embeddings(self, prompt, text_input_ids):
70 prompt = " ".join(prompt)
71
72 embeddings = [
73 embedding
74 for key, embedding in self.aesthetic_gradient_embeddings.items()
75 if key in prompt
76 ]
77
78 if len(embeddings) != 0:
79 with torch.enable_grad():
80 full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
81 full_clip_model.to(self.device)
82 full_clip_model.text_model.train()
83
84 optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr)
85
86 for embs in embeddings:
87 embs = embs.clone().detach().to(self.device)
88 embs /= embs.norm(dim=-1, keepdim=True)
89
90 for i in range(self.aesthetic_gradient_iters):
91 text_embs = full_clip_model.get_text_features(text_input_ids)
92 text_embs /= text_embs.norm(dim=-1, keepdim=True)
93 sim = text_embs @ embs.T
94 loss = -sim
95 loss = loss.mean()
96
97 loss.backward()
98 optimizer.step()
99 optimizer.zero_grad()
100
101 full_clip_model.text_model.eval()
102
103 return full_clip_model.text_model(text_input_ids)[0]
104 else:
105 return self.text_encoder(text_input_ids)[0]
106
61 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 107 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
62 r""" 108 r"""
63 Enable sliced attention computation. 109 Enable sliced attention computation.
@@ -195,7 +241,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
195 ) 241 )
196 print(f"Too many tokens: {removed_text}") 242 print(f"Too many tokens: {removed_text}")
197 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] 243 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
198 text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] 244 text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device))
199 245
200 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 246 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
201 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 247 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -207,7 +253,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
207 uncond_input = self.tokenizer( 253 uncond_input = self.tokenizer(
208 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" 254 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
209 ) 255 )
210 uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 256 uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device))
211 257
212 # For classifier free guidance, we need to do two forward passes. 258 # For classifier free guidance, we need to do two forward passes.
213 # Here we concatenate the unconditional and text embeddings into a single batch 259 # Here we concatenate the unconditional and text embeddings into a single batch