from typing import Optional

import torch
import torch.nn as nn


class PseudoSparseEmbedding(nn.Module):
    """Stores trainable embeddings only for the token ids that are explicitly set,
    instead of allocating a dense [num_embeddings, embedding_dim] weight matrix.

    ``mapping[token_id]`` holds the index of that token's entry in ``self.params``,
    or -1 if no embedding is stored for it.
    """

    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dtype = dtype
        # One nn.Parameter of shape [embedding_dim] per stored token.
        self.params = nn.ParameterList()
        # token id -> index into self.params; -1 means "no embedding stored".
        self.mapping = torch.zeros(0, device=device, dtype=torch.long)

    def forward(self, input_ids: torch.LongTensor):
        ids = self.mapping[input_ids.to(self.mapping.device)]
        mask = ids != -1
        if not mask.any():
            # None of the requested token ids has a stored embedding.
            embs = None
        else:
            embs = torch.stack([self.params[id] for id in ids[mask]])
        return embs, mask

    def resize(self, new_num_embeddings: int):
        old_num_embeddings = self.mapping.shape[0]
        n = min(old_num_embeddings, new_num_embeddings)
        # New slots start out unmapped (-1); existing mappings are carried over.
        new_mapping = torch.full(
            (new_num_embeddings,), -1, device=self.mapping.device, dtype=torch.long
        )
        new_mapping[:n] = self.mapping[:n]
        self.mapping = new_mapping

    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
        if len(input_ids.shape) != 0:
            # Batched call: recurse over each (token id, embedding) pair.
            if tensor is not None:
                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
            else:
                return [self.set(id) for id in input_ids]

        if tensor is None:
            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)

        if tensor.shape[-1] != self.embedding_dim:
            raise ValueError(
                f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]"
            )

        id = self.mapping[input_ids]
        if id == -1:
            # First time this token id is set: allocate a new parameter slot for it.
            id = len(self.params)
            self.mapping[input_ids] = id
            self.params.append(
                nn.Parameter(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
            )
        # Wrapping in nn.Parameter keeps the stored embedding registered as trainable.
        self.params[id] = nn.Parameter(tensor)

    def unset(self, input_ids: torch.LongTensor):
        # Mark these token ids as unmapped; their slots in self.params are simply no longer referenced.
        self.mapping[input_ids] = -1
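A minimal usage sketch (not part of the original listing; the embedding size 768 and the vocabulary of 10 token ids are arbitrary choices for illustration):

emb = PseudoSparseEmbedding(embedding_dim=768)
emb.resize(10)                               # room for token ids 0..9, all unmapped (-1)

emb.set(torch.tensor(3))                     # allocate a zero-initialized embedding for id 3
emb.set(torch.tensor(7), torch.randn(768))   # store an explicit embedding for id 7

embs, mask = emb(torch.tensor([1, 3, 7]))
# mask == tensor([False, True, True]): only ids 3 and 7 have stored embeddings
# embs.shape == torch.Size([2, 768]): stacked embeddings for the masked-in ids

emb.unset(torch.tensor(3))                   # id 3 is treated as "not stored" again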