summaryrefslogtreecommitdiffstats
path: root/pipelines
diff options
context:
space:
mode:
Diffstat (limited to 'pipelines')
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py50
1 files changed, 4 insertions, 46 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 1a84c8d..3e41f86 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -51,10 +51,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
51 new_config["steps_offset"] = 1 51 new_config["steps_offset"] = 1
52 scheduler._internal_dict = FrozenDict(new_config) 52 scheduler._internal_dict = FrozenDict(new_config)
53 53
54 self.aesthetic_gradient_embeddings = {}
55 self.aesthetic_gradient_lr = 1e-4
56 self.aesthetic_gradient_iters = 10
57
58 self.register_modules( 54 self.register_modules(
59 vae=vae, 55 vae=vae,
60 text_encoder=text_encoder, 56 text_encoder=text_encoder,
@@ -63,46 +59,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
63 scheduler=scheduler, 59 scheduler=scheduler,
64 ) 60 )
65 61
66 def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor): 62 def get_text_embeddings(self, text_input_ids):
67 self.aesthetic_gradient_embeddings[keyword] = tensor 63 return self.text_encoder(text_input_ids)[0]
68
69 def get_text_embeddings(self, prompt, text_input_ids):
70 prompt = " ".join(prompt)
71
72 embeddings = [
73 embedding
74 for key, embedding in self.aesthetic_gradient_embeddings.items()
75 if key in prompt
76 ]
77
78 if len(embeddings) != 0:
79 with torch.enable_grad():
80 full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
81 full_clip_model.to(self.device)
82 full_clip_model.text_model.train()
83
84 optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr)
85
86 for embs in embeddings:
87 embs = embs.clone().detach().to(self.device)
88 embs /= embs.norm(dim=-1, keepdim=True)
89
90 for i in range(self.aesthetic_gradient_iters):
91 text_embs = full_clip_model.get_text_features(text_input_ids)
92 text_embs /= text_embs.norm(dim=-1, keepdim=True)
93 sim = text_embs @ embs.T
94 loss = -sim
95 loss = loss.mean()
96
97 loss.backward()
98 optimizer.step()
99 optimizer.zero_grad()
100
101 full_clip_model.text_model.eval()
102
103 return full_clip_model.text_model(text_input_ids)[0]
104 else:
105 return self.text_encoder(text_input_ids)[0]
106 64
107 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 65 def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
108 r""" 66 r"""
@@ -241,7 +199,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
241 ) 199 )
242 print(f"Too many tokens: {removed_text}") 200 print(f"Too many tokens: {removed_text}")
243 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] 201 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
244 text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device)) 202 text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device))
245 203
246 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 204 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
247 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 205 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -253,7 +211,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
253 uncond_input = self.tokenizer( 211 uncond_input = self.tokenizer(
254 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" 212 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
255 ) 213 )
256 uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device)) 214 uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device))
257 215
258 # For classifier free guidance, we need to do two forward passes. 216 # For classifier free guidance, we need to do two forward passes.
259 # Here we concatenate the unconditional and text embeddings into a single batch 217 # Here we concatenate the unconditional and text embeddings into a single batch