From 515f0f1fdc9a76bf63bd746c291dcfec7fc747fb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 13 Oct 2022 21:11:53 +0200 Subject: Added support for Aesthetic Gradients --- .../stable_diffusion/vlpn_stable_diffusion.py | 52 ++++++++++++++++++++-- 1 file changed, 49 insertions(+), 3 deletions(-) (limited to 'pipelines/stable_diffusion/vlpn_stable_diffusion.py') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 8927a78..1a84c8d 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -4,13 +4,14 @@ from typing import List, Optional, Union import numpy as np import torch +import torch.optim as optim import PIL from diffusers.configuration_utils import FrozenDict from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel from schedulers.scheduling_euler_a import EulerAScheduler logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -50,6 +51,10 @@ class VlpnStableDiffusion(DiffusionPipeline): new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + self.aesthetic_gradient_embeddings = {} + self.aesthetic_gradient_lr = 1e-4 + self.aesthetic_gradient_iters = 10 + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -58,6 +63,47 @@ class VlpnStableDiffusion(DiffusionPipeline): scheduler=scheduler, ) + def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor): + self.aesthetic_gradient_embeddings[keyword] = tensor + + def get_text_embeddings(self, prompt, text_input_ids): + prompt = " ".join(prompt) + + embeddings = [ + embedding + for key, embedding in self.aesthetic_gradient_embeddings.items() + if key in prompt + ] + + if len(embeddings) != 0: + with torch.enable_grad(): + full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + full_clip_model.to(self.device) + full_clip_model.text_model.train() + + optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr) + + for embs in embeddings: + embs = embs.clone().detach().to(self.device) + embs /= embs.norm(dim=-1, keepdim=True) + + for i in range(self.aesthetic_gradient_iters): + text_embs = full_clip_model.get_text_features(text_input_ids) + text_embs /= text_embs.norm(dim=-1, keepdim=True) + sim = text_embs @ embs.T + loss = -sim + loss = loss.mean() + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + full_clip_model.text_model.eval() + + return full_clip_model.text_model(text_input_ids)[0] + else: + return self.text_encoder(text_input_ids)[0] + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. @@ -195,7 +241,7 @@ class VlpnStableDiffusion(DiffusionPipeline): ) print(f"Too many tokens: {removed_text}") text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device)) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -207,7 +253,7 @@ class VlpnStableDiffusion(DiffusionPipeline): uncond_input = self.tokenizer( negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device)) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch -- cgit v1.2.3-54-g00ecf