From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 13:49:35 +0100 Subject: Removed PromptProcessor, modularized training loop --- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) (limited to 'pipelines') diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 6bc40e9..a5cfc60 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -22,7 +22,7 @@ from diffusers import ( from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging, randn_tensor from transformers import CLIPTextModel, CLIPTokenizer -from models.clip.prompt import PromptProcessor +from models.clip.util import unify_input_ids, get_extended_embeddings logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -70,8 +70,6 @@ class VlpnStableDiffusion(DiffusionPipeline): new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - self.prompt_processor = PromptProcessor(tokenizer, text_encoder) - self.register_modules( vae=vae, text_encoder=text_encoder, @@ -213,16 +211,22 @@ class VlpnStableDiffusion(DiffusionPipeline): do_classifier_free_guidance: bool, device ): - text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt + if isinstance(prompt[0], str): + text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids + else: + text_input_ids = prompt + text_input_ids *= num_images_per_prompt if do_classifier_free_guidance: - unconditional_input_ids = self.prompt_processor.get_input_ids( - negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt + if isinstance(prompt[0], str): + unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids + else: + unconditional_input_ids = negative_prompt unconditional_input_ids *= num_images_per_prompt text_input_ids = unconditional_input_ids + text_input_ids - text_inputs = self.prompt_processor.unify_input_ids(text_input_ids) + text_inputs = unify_input_ids(self.tokenizer, text_input_ids) text_input_ids = text_inputs.input_ids if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: @@ -230,7 +234,7 @@ class VlpnStableDiffusion(DiffusionPipeline): else: attention_mask = None - text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask) + text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) return text_embeddings -- cgit v1.2.3-54-g00ecf