diff options
| author | Volpeon <git@volpeon.ink> | 2023-05-16 12:59:08 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-05-16 12:59:08 +0200 |
| commit | 1aace3e44dae0489130039714f67d980628c92ec (patch) | |
| tree | 59a972b64bb3a3253e310055fc24381db68e8608 /models/clip | |
| parent | Patch xformers to cast dtypes (diff) | |
| download | textual-inversion-diff-1aace3e44dae0489130039714f67d980628c92ec.tar.gz textual-inversion-diff-1aace3e44dae0489130039714f67d980628c92ec.tar.bz2 textual-inversion-diff-1aace3e44dae0489130039714f67d980628c92ec.zip | |
Avoid model recompilation due to varying prompt lengths
Diffstat (limited to 'models/clip')
| -rw-r--r-- | models/clip/util.py | 23 |
1 files changed, 15 insertions, 8 deletions
diff --git a/models/clip/util.py b/models/clip/util.py index 883de6a..f94fbc7 100644 --- a/models/clip/util.py +++ b/models/clip/util.py | |||
| @@ -5,14 +5,21 @@ import torch | |||
| 5 | from transformers import CLIPTokenizer, CLIPTextModel | 5 | from transformers import CLIPTokenizer, CLIPTextModel |
| 6 | 6 | ||
| 7 | 7 | ||
| 8 | def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]): | 8 | def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None): |
| 9 | return tokenizer.pad( | 9 | if max_length is None: |
| 10 | {"input_ids": input_ids}, | 10 | return tokenizer.pad( |
| 11 | padding=True, | 11 | {"input_ids": input_ids}, |
| 12 | pad_to_multiple_of=tokenizer.model_max_length, | 12 | padding=True, |
| 13 | return_tensors="pt" | 13 | pad_to_multiple_of=tokenizer.model_max_length, |
| 14 | ) | 14 | return_tensors="pt" |
| 15 | 15 | ) | |
| 16 | else: | ||
| 17 | return tokenizer.pad( | ||
| 18 | {"input_ids": input_ids}, | ||
| 19 | padding="max_length", | ||
| 20 | max_length=max_length, | ||
| 21 | return_tensors="pt" | ||
| 22 | ) | ||
| 16 | 23 | ||
| 17 | def get_extended_embeddings( | 24 | def get_extended_embeddings( |
| 18 | text_encoder: CLIPTextModel, | 25 | text_encoder: CLIPTextModel, |
