path: root/models/clip/tokenizer.py
blob: 78871db3e3252f961e0b16147aff143b143be3c0
import copy
from typing import NamedTuple, Union

import numpy as np

from transformers import CLIPTokenizer


class MultiCLIPTokenizerItem(NamedTuple):
    """Result of registering a multi-vector token: the meta token string,
    its placeholder id, and the ids of its per-vector sub-tokens."""
    token: str
    placeholder_id: int
    multi_ids: list[int]


class MultiCLIPTokenizer(CLIPTokenizer):
    """CLIPTokenizer that lets a single "meta" token stand for multiple
    embedding vectors by expanding it into several sub-tokens at encode time."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps a meta token id to the ids of the sub-tokens it expands to.
        self.token_map: dict[int, list[int]] = {}

    def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same length")

            # Recurse once per token so each gets its own map entry.
            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be an int for a single token")

        super().add_tokens(new_tokens)

        if num_vectors == 1:
            # A single vector needs no sub-tokens; the meta token maps to itself.
            multi_token = [new_tokens]
        else:
            # Register one sub-token per vector: "<token>_0" .. "<token>_{n-1}".
            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
            super().add_tokens(multi_token)

        meta_id = super().convert_tokens_to_ids(new_tokens)
        multi_ids = super().convert_tokens_to_ids(multi_token)

        self.token_map[meta_id] = multi_ids

        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)

    def encode(self, *args, vector_shuffle=True, **kwargs):
        ids = super().encode(*args, **kwargs)
        new_ids = []

        for token_id in ids:
            if token_id in self.token_map:
                tokens = self.token_map[token_id]

                if vector_shuffle:
                    # Shuffle a copy so the stored mapping keeps its order.
                    tokens = copy.copy(tokens)
                    np.random.shuffle(tokens)

                # Extend with the (possibly shuffled) copy, not the stored list.
                new_ids += tokens
            else:
                new_ids.append(token_id)

        return new_ids
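

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumes the
    # "openai/clip-vit-large-patch14" checkpoint is available; any CLIP
    # tokenizer checkpoint would work the same way, and "<cat>" is an
    # arbitrary placeholder token chosen for illustration.
    tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # "<cat>" becomes a meta token backed by sub-tokens "<cat>_0".."<cat>_2".
    item = tokenizer.add_multi_tokens("<cat>", num_vectors=3)
    print(item.token, item.placeholder_id, item.multi_ids)

    # On encode, the meta token id is replaced by its (shuffled) sub-token ids.
    print(tokenizer.encode("a photo of a <cat>", add_special_tokens=False))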