path: root/models/clip/prompt.py
from typing import Union, Optional

import torch

from transformers import CLIPTokenizer, CLIPTextModel


class PromptProcessor:
    """Turns text prompts into CLIP text embeddings.

    Prompts longer than the encoder's maximum sequence length are padded to a
    multiple of that length and encoded in chunks, so arbitrarily long prompts
    are supported.
    """

    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def get_input_ids(self, prompt: Union[str, list[str]]):
        # Tokenize without padding so each prompt keeps its natural length;
        # padding is applied later in unify_input_ids.
        return self.tokenizer(
            prompt,
            padding="do_not_pad",
        ).input_ids

    def unify_input_ids(self, input_ids: list[list[int]]):
        # Pad the batch to a common length that is a multiple of the encoder's
        # maximum sequence length, so each sequence can later be split into
        # encoder-sized chunks.
        return self.tokenizer.pad(
            {"input_ids": input_ids},
            padding=True,
            pad_to_multiple_of=self.tokenizer.model_max_length,
            return_tensors="pt"
        )

    def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
        num_prompts = input_ids.shape[0]

        # Split each (possibly long) sequence into chunks of the encoder's
        # maximum length and move everything to the encoder's device.
        input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
        if position_ids is not None:
            position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
        if attention_mask is not None:
            attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)

        # Encode all chunks in a single batch, then stitch each prompt's
        # chunks back together along the sequence dimension.
        text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
        text_embeddings = text_embeddings.view((num_prompts, -1, text_embeddings.shape[2]))
        return text_embeddings
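

# A minimal usage sketch showing how the three steps chain together. The
# checkpoint name below is an illustrative assumption; any CLIP checkpoint
# with a matching tokenizer and text encoder works the same way.
if __name__ == "__main__":
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    processor = PromptProcessor(tokenizer, text_encoder)

    # Tokenize, pad to a multiple of the encoder's max length, then encode.
    ids = processor.get_input_ids(["a photo of a cat", "a painting of a dog"])
    batch = processor.unify_input_ids(ids)
    embeddings = processor.get_embeddings(
        batch.input_ids, attention_mask=batch.attention_mask
    )
    print(embeddings.shape)  # (2, padded_length, hidden_size)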