path: root/models/clip/tokenizer.py
from typing import NamedTuple, Union

import numpy as np

from transformers import CLIPTokenizer


class MultiCLIPTokenizerItem(NamedTuple):
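    """A placeholder token together with the ids of all vectors added for it."""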
    token: str
    ids: list[int]


class MultiCLIPTokenizer(CLIPTokenizer):
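    """CLIPTokenizer that can expand one placeholder token into several token ids.

    Each token registered via `add_multi_tokens` reserves `num_vectors`
    vocabulary entries; `token_map` records the full id sequence under the
    id of the base token so it can be expanded at encode time.
    """
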
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_map: dict[int, list[int]] = {}

    def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
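        """Add one or more placeholder tokens, each backed by `num_vectors` ids.

        For a token "tok" with num_vectors=3, the strings "tok", "tok_1" and
        "tok_2" are added to the vocabulary, and their ids are stored in
        `token_map` under the id of "tok".
        """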
        if isinstance(new_tokens, list):
            if isinstance(num_vectors, int):
                num_vectors = [num_vectors] * len(new_tokens)

            if len(num_vectors) != len(new_tokens):
                raise ValueError("Expected new_tokens and num_vectors to have the same len")

            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]

        if isinstance(num_vectors, list):
            raise ValueError("Expected num_vectors to be int for single token")

        if num_vectors < 1:
            raise ValueError("Expected num_vectors to be >= 1")

        multi_token = [new_tokens] + [f"{new_tokens}_{i}" for i in range(1, num_vectors)]

        super().add_tokens(multi_token)

        ids = super().convert_tokens_to_ids(multi_token)

        self.token_map[ids[0]] = ids

        return MultiCLIPTokenizerItem(new_tokens, ids)

    def expand_id(self, id: int, vector_shuffle: bool = True) -> list[int]:
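        """Expand a placeholder id into its full id sequence.

        With `vector_shuffle`, the inner ids are permuted while the first
        and last ids keep their positions. An id that is not in `token_map`
        is returned unchanged as a one-element list.
        """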
        if id in self.token_map:
            tokens = self.token_map[id]

            if vector_shuffle and len(tokens) > 2:
                subtokens = tokens[1:-1]
                np.random.shuffle(subtokens)
                tokens = tokens[:1] + subtokens + tokens[-1:]

            return tokens
        else:
            return [id]

    def expand_ids(self, ids: list[int], vector_shuffle: bool = True) -> list[int]:
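        """Expand every id in `ids` and flatten the result into a single list."""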
        return [
            new_id
            for id in ids
            for new_id in self.expand_id(id, vector_shuffle)
        ]

    def _call_one(self, text, *args, vector_shuffle: bool = True, **kwargs):
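        """Run the regular tokenizer, then expand placeholder ids in the result.

        Handles both a single encoding (flat list of ids) and a batch
        (list of id lists). Only `input_ids` is rewritten; other fields
        (e.g. `attention_mask`) are left as produced by the base tokenizer.
        """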
        result = super()._call_one(text, *args, **kwargs)

        is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)

        if is_batched:
            result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
        else:
            result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)

        return result
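

# Minimal usage sketch (illustrative, not part of the module). The checkpoint
# name and the "<my-style>" placeholder below are assumptions; any CLIP
# tokenizer checkpoint works the same way.
if __name__ == "__main__":
    tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # Reserve three embedding ids for a single placeholder token.
    item = tokenizer.add_multi_tokens("<my-style>", num_vectors=3)
    print(item.token, item.ids)  # "<my-style>" and its three new ids

    # Encoding expands the placeholder id into all three ids.
    encoded = tokenizer("a photo in <my-style> style", vector_shuffle=False)
    print(encoded.input_ids)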