author    Volpeon <git@volpeon.ink>    2023-01-13 13:49:35 +0100
committer Volpeon <git@volpeon.ink>    2023-01-13 13:49:35 +0100
commit    7b149930bb53b93db74106ad20a30abf4b114f9b (patch)
tree      67c2ccbce2a9838ad8a020ee527b19113e67e30a /pipelines
parent    Added TI decay start offset (diff)
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'pipelines')
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  20
1 file changed, 12 insertions(+), 8 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 6bc40e9..a5cfc60 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -22,7 +22,7 @@ from diffusers import (
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
-from models.clip.prompt import PromptProcessor
+from models.clip.util import unify_input_ids, get_extended_embeddings
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -70,8 +70,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -213,16 +211,22 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance: bool,
         device
     ):
-        text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt
+        if isinstance(prompt[0], str):
+            text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids
+        else:
+            text_input_ids = prompt
+
         text_input_ids *= num_images_per_prompt
 
         if do_classifier_free_guidance:
-            unconditional_input_ids = self.prompt_processor.get_input_ids(
-                negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt
+            if isinstance(prompt[0], str):
+                unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids
+            else:
+                unconditional_input_ids = negative_prompt
             unconditional_input_ids *= num_images_per_prompt
             text_input_ids = unconditional_input_ids + text_input_ids
 
-        text_inputs = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_inputs = unify_input_ids(self.tokenizer, text_input_ids)
         text_input_ids = text_inputs.input_ids
 
         if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
@@ -230,7 +234,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             attention_mask = None
 
-        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask)
+        text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
 
         return text_embeddings
 
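Note: the diff only touches the call sites; the new helpers live in models/clip/util.py, which is outside this diffstat. For orientation, below is a minimal sketch of what unify_input_ids and get_extended_embeddings could look like, inferred purely from how they are called above. The bodies, type hints, and padding strategy are assumptions, not the repository's actual code.

# Hypothetical sketch of models/clip/util.py -- inferred from the call sites
# in this diff, not taken from the repository.
from typing import List

import torch
from transformers import CLIPTextModel, CLIPTokenizer


def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: List[List[int]]):
    # Pad the un-padded per-prompt id lists (tokenized with padding="do_not_pad")
    # to a common length that is a multiple of the tokenizer's max length and
    # build matching attention masks. Returns a BatchEncoding exposing
    # .input_ids and .attention_mask, which is how the pipeline uses the result.
    return tokenizer.pad(
        {"input_ids": input_ids},
        padding=True,
        pad_to_multiple_of=tokenizer.model_max_length,
        return_tensors="pt",
    )


def get_extended_embeddings(
    text_encoder: CLIPTextModel,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
) -> torch.Tensor:
    # Encode prompts that may be longer than the text encoder's context window:
    # fold each padded sequence into chunks of max_position_embeddings, encode
    # all chunks in one batch, then stitch the chunk embeddings back per prompt.
    # Assumes tokenizer.model_max_length == max_position_embeddings (77 for CLIP).
    chunk_length = text_encoder.config.max_position_embeddings
    prompts = input_ids.shape[0]

    input_ids = input_ids.view((-1, chunk_length)).to(text_encoder.device)
    if attention_mask is not None:
        attention_mask = attention_mask.view((-1, chunk_length)).to(text_encoder.device)

    embeddings = text_encoder(input_ids, attention_mask=attention_mask)[0]
    return embeddings.view((prompts, -1, embeddings.shape[2]))

With helpers shaped like this, the pipeline code in the last two hunks reduces to unify_input_ids(self.tokenizer, text_input_ids) followed by get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask), with no per-pipeline PromptProcessor state to construct.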