From 6a49074dce78615bce54777fb2be3bfd0dd8f780 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 14 Oct 2022 20:03:01 +0200
Subject: Removed aesthetic gradients; training improvements

---
 .../stable_diffusion/vlpn_stable_diffusion.py | 50 ++--------------------
 1 file changed, 4 insertions(+), 46 deletions(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 1a84c8d..3e41f86 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -51,10 +51,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        self.aesthetic_gradient_embeddings = {}
-        self.aesthetic_gradient_lr = 1e-4
-        self.aesthetic_gradient_iters = 10
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -63,46 +59,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
-    def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor):
-        self.aesthetic_gradient_embeddings[keyword] = tensor
-
-    def get_text_embeddings(self, prompt, text_input_ids):
-        prompt = " ".join(prompt)
-
-        embeddings = [
-            embedding
-            for key, embedding in self.aesthetic_gradient_embeddings.items()
-            if key in prompt
-        ]
-
-        if len(embeddings) != 0:
-            with torch.enable_grad():
-                full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-                full_clip_model.to(self.device)
-                full_clip_model.text_model.train()
-
-                optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr)
-
-                for embs in embeddings:
-                    embs = embs.clone().detach().to(self.device)
-                    embs /= embs.norm(dim=-1, keepdim=True)
-
-                    for i in range(self.aesthetic_gradient_iters):
-                        text_embs = full_clip_model.get_text_features(text_input_ids)
-                        text_embs /= text_embs.norm(dim=-1, keepdim=True)
-                        sim = text_embs @ embs.T
-                        loss = -sim
-                        loss = loss.mean()
-
-                        loss.backward()
-                        optimizer.step()
-                        optimizer.zero_grad()
-
-                full_clip_model.text_model.eval()
-
-                return full_clip_model.text_model(text_input_ids)[0]
-        else:
-            return self.text_encoder(text_input_ids)[0]
+    def get_text_embeddings(self, text_input_ids):
+        return self.text_encoder(text_input_ids)[0]
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -241,7 +199,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device))
+        text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device))
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -253,7 +211,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device))
+            uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device))
 
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
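
Note on the removed code: the deleted branch of get_text_embeddings implemented the
aesthetic-gradients trick, which briefly fine-tunes CLIP's text encoder toward a target
"aesthetic" embedding before encoding the prompt. Below is a minimal standalone sketch of
that step, reconstructed from the deleted lines above. The function name
apply_aesthetic_gradient and its argument names are illustrative only; the model name
("openai/clip-vit-large-patch14") and the lr/iteration defaults are taken straight from
the removed code.

    # Illustrative reconstruction of the removed aesthetic-gradient step.
    # Nudges CLIP's text encoder toward a target embedding by maximizing
    # cosine similarity for a few Adam steps, then re-encodes the prompt.
    import torch
    import torch.optim as optim
    from transformers import CLIPModel

    def apply_aesthetic_gradient(text_input_ids: torch.Tensor,
                                 aesthetic_emb: torch.Tensor,
                                 lr: float = 1e-4,   # removed aesthetic_gradient_lr
                                 iters: int = 10,    # removed aesthetic_gradient_iters
                                 device: str = "cpu") -> torch.Tensor:
        clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        clip.text_model.train()
        optimizer = optim.Adam(clip.text_model.parameters(), lr=lr)

        text_input_ids = text_input_ids.to(device)

        # Normalize the target embedding once; it is treated as a constant.
        target = aesthetic_emb.clone().detach().to(device)
        target = target / target.norm(dim=-1, keepdim=True)

        with torch.enable_grad():
            for _ in range(iters):
                text_embs = clip.get_text_features(text_input_ids)
                text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
                loss = -(text_embs @ target.T).mean()  # maximize cosine similarity
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

        clip.text_model.eval()
        # Return the fine-tuned encoder's hidden states, as the deleted code did.
        return clip.text_model(text_input_ids)[0]

Dropping this path removes a full CLIP model load plus ten optimizer steps from every
conditioned prompt encode, which is why the replacement reduces to a single
self.text_encoder(text_input_ids)[0] call.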