| author | Volpeon <git@volpeon.ink> | 2023-01-13 13:49:35 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-13 13:49:35 +0100 |
| commit | 7b149930bb53b93db74106ad20a30abf4b114f9b | |
| tree | 67c2ccbce2a9838ad8a020ee527b19113e67e30a /pipelines/stable_diffusion | |
| parent | Added TI decay start offset | |
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'pipelines/stable_diffusion')
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 20 |
1 file changed, 12 insertions(+), 8 deletions(-)
```diff
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 6bc40e9..a5cfc60 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -22,7 +22,7 @@ from diffusers import (
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
-from models.clip.prompt import PromptProcessor
+from models.clip.util import unify_input_ids, get_extended_embeddings
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -70,8 +70,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -213,16 +211,22 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance: bool,
         device
     ):
-        text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt
+        if isinstance(prompt[0], str):
+            text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids
+        else:
+            text_input_ids = prompt
+
         text_input_ids *= num_images_per_prompt
 
         if do_classifier_free_guidance:
-            unconditional_input_ids = self.prompt_processor.get_input_ids(
-                negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt
+            if isinstance(prompt[0], str):
+                unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids
+            else:
+                unconditional_input_ids = negative_prompt
             unconditional_input_ids *= num_images_per_prompt
             text_input_ids = unconditional_input_ids + text_input_ids
 
-        text_inputs = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_inputs = unify_input_ids(self.tokenizer, text_input_ids)
         text_input_ids = text_inputs.input_ids
 
         if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
@@ -230,7 +234,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         else:
             attention_mask = None
 
-        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids, attention_mask)
+        text_embeddings = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask)
 
         return text_embeddings
 
```

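The inlined tokenization in the prompt-encoding method relies on plain Python list semantics: with `padding="do_not_pad"` the tokenizer returns ragged lists of token ids, `*=` repeats the whole batch, and `+` prepends the unconditional rows for classifier-free guidance. A standalone illustration (the checkpoint name is an assumption, just the stock CLIP tokenizer, not something this diff pins down):

```python
# Illustration of the inlined tokenization logic; not part of the commit.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompt = ["a photo of a cat", "a photo of a dog"]
negative_prompt = ["", ""]
num_images_per_prompt = 2

# padding="do_not_pad" keeps per-prompt lengths: a list of lists of ids.
text_input_ids = tokenizer(prompt, padding="do_not_pad").input_ids

# list *= n repeats the whole batch, so [a, b] becomes [a, b, a, b].
text_input_ids *= num_images_per_prompt

unconditional_input_ids = tokenizer(negative_prompt, padding="do_not_pad").input_ids
unconditional_input_ids *= num_images_per_prompt

# Unconditional rows first, conditional rows second, matching the pipeline.
text_input_ids = unconditional_input_ids + text_input_ids
print(len(text_input_ids))  # 8 rows: 4 unconditional + 4 conditional
```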