diff options
author | Volpeon <git@volpeon.ink> | 2023-01-13 13:49:35 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-13 13:49:35 +0100 |
commit | 7b149930bb53b93db74106ad20a30abf4b114f9b (patch) | |
tree | 67c2ccbce2a9838ad8a020ee527b19113e67e30a /pipelines/stable_diffusion | |
parent | Added TI decay start offset (diff) | |
download | textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.gz textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.bz2 textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.zip |
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 20 |
1 file changed, 12 insertions(+), 8 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 6bc40e9..a5cfc60 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -22,7 +22,7 @@ from diffusers import ( | |||
22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 22 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
23 | from diffusers.utils import logging, randn_tensor | 23 | from diffusers.utils import logging, randn_tensor |
24 | from transformers import CLIPTextModel, CLIPTokenizer | 24 | from transformers import CLIPTextModel, CLIPTokenizer |
25 | from models.clip.prompt import PromptProcessor | 25 | from models.clip.util import unify_input_ids, get_extended_embeddings |
26 | 26 | ||
27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
28 | 28 | ||
@@ -70,8 +70,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
70 | new_config["steps_offset"] = 1 | 70 | new_config["steps_offset"] = 1 |
71 | scheduler._internal_dict = FrozenDict(new_config) | 71 | scheduler._internal_dict = FrozenDict(new_config) |
72 | 72 | ||
73 | self.prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
74 | |||
75 | self.register_modules( | 73 | self.register_modules( |
76 | vae=vae, | 74 | vae=vae, |
77 | text_encoder=text_encoder, | 75 | text_encoder=text_encoder, |
@@ -213,16 +211,22 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
213 | do_classifier_free_guidance: bool, | 211 | do_classifier_free_guidance: bool, |
214 | device | 212 | device |
215 | ): | 213 | ): |
216 | text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt | 214 | if isinstance(prompt[0], str): |
215 | text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids | ||
216 | else: | ||
217 | text_input_ids = prompt | ||
218 | |||
217 | text_input_ids *= num_images_per_prompt | 219 | text_input_ids *= num_images_per_prompt |
218 | 220 | ||
219 | if do_classifier_free_guidance: | 221 | if do_classifier_free_guidance: |
220 | unconditional_input_ids = self.prompt_processor.get_input_ids( | 222 | if isinstance(prompt[0], str): |
221 | negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt | 223 | unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids |
224 | else: | ||
225 | unconditional_input_ids = negative_prompt | ||
222 | unconditional_input_ids *= num_images_per_prompt | 226 | unconditional_input_ids *= num_images_per_prompt |
223 | text_input_ids = unconditional_input_ids + text_input_ids | 227 | text_input_ids = unconditional_input_ids + text_input_ids |
224 | 228 | ||
225 | text_inputs = self.prompt_processor.unify_input_ids(text_input_ids) | 229 | text_inputs = unify_input_ids(self.tokenizer, text_input_ids) |
226 | text_input_ids = text_inputs.input_ids | 230 | text_input_ids = text_inputs.input_ids |
227 | 231 | ||
228 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | 232 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: |
@@ -230,7 +234,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
230 | else: | 234 | else: |
231 | attention_mask = None | 235 | attention_mask = None |
232 | 236 | ||
233 | text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask) | 237 | text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) |
234 | 238 | ||
235 | return text_embeddings | 239 | return text_embeddings |
236 | 240 | ||