From 515f0f1fdc9a76bf63bd746c291dcfec7fc747fb Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Thu, 13 Oct 2022 21:11:53 +0200
Subject: Added support for Aesthetic Gradients

---
 .../stable_diffusion/vlpn_stable_diffusion.py      | 52 ++++++++++++++++++++--
 1 file changed, 49 insertions(+), 3 deletions(-)

(limited to 'pipelines')

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 8927a78..1a84c8d 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -4,13 +4,14 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
+import torch.optim as optim
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -50,6 +51,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        self.aesthetic_gradient_embeddings = {}
+        self.aesthetic_gradient_lr = 1e-4
+        self.aesthetic_gradient_iters = 10
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -58,6 +63,47 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
+    def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor):
+        self.aesthetic_gradient_embeddings[keyword] = tensor
+
+    def get_text_embeddings(self, prompt, text_input_ids):
+        prompt = " ".join(prompt)
+
+        embeddings = [
+            embedding
+            for key, embedding in self.aesthetic_gradient_embeddings.items()
+            if key in prompt
+        ]
+
+        if len(embeddings) != 0:
+            with torch.enable_grad():
+                full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+                full_clip_model.to(self.device)
+                full_clip_model.text_model.train()
+
+                optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr)
+
+                for embs in embeddings:
+                    embs = embs.clone().detach().to(self.device)
+                    embs /= embs.norm(dim=-1, keepdim=True)
+
+                    for i in range(self.aesthetic_gradient_iters):
+                        text_embs = full_clip_model.get_text_features(text_input_ids)
+                        text_embs /= text_embs.norm(dim=-1, keepdim=True)
+                        sim = text_embs @ embs.T
+                        loss = -sim
+                        loss = loss.mean()
+
+                        loss.backward()
+                        optimizer.step()
+                        optimizer.zero_grad()
+
+            full_clip_model.text_model.eval()
+
+            return full_clip_model.text_model(text_input_ids)[0]
+        else:
+            return self.text_encoder(text_input_ids)[0]
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -195,7 +241,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device))
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -207,7 +253,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device))
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
-- 
cgit v1.2.3-70-g09d2