path: root/models/clip/util.py
from typing import Optional

import torch

from transformers import CLIPTokenizer, CLIPTextModel


def unify_input_ids(
    tokenizer: CLIPTokenizer,
    input_ids: list[list[int]],
    max_length: Optional[int] = None,
):
    """Pad a batch of token id lists into a single rectangular tensor.

    If no explicit max_length is given, the batch is padded to its longest
    entry and then rounded up to a multiple of the tokenizer's
    model_max_length, so the result can later be split into full-length
    chunks. Otherwise the batch is padded to exactly max_length.
    """
    if max_length is None:
        return tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            pad_to_multiple_of=tokenizer.model_max_length,
            return_tensors="pt",
        )
    else:
        return tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )


def get_extended_embeddings(
    text_encoder: CLIPTextModel,
    input_ids: torch.LongTensor,
    position_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
):
    """Encode prompts that may exceed the text encoder's context window.

    Each prompt is expected to be padded to a multiple of the encoder's
    max_position_embeddings (see unify_input_ids). The prompt is split into
    full-length chunks, each chunk is encoded independently, and the chunk
    embeddings are concatenated back along the sequence dimension.
    """
    model_max_length = text_encoder.config.max_position_embeddings
    prompts = input_ids.shape[0]

    # Fold the chunk dimension into the batch: (batch, n * max) -> (batch * n, max).
    input_ids = input_ids.view((-1, model_max_length))
    if position_ids is not None:
        position_ids = position_ids.view((-1, model_max_length))
    if attention_mask is not None:
        attention_mask = attention_mask.view((-1, model_max_length))

    text_embeddings = text_encoder(
        input_ids, position_ids=position_ids, attention_mask=attention_mask
    )[0]
    # Unfold back to one embedding sequence per prompt: (batch, n * max, hidden).
    text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
    return text_embeddings