From 99b4dba56e3e1e434820d1221d561e90f1a6d30a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 15 Apr 2023 13:11:11 +0200
Subject: TI via LoRA

---
 models/sparse.py | 66 --------------------------------------------------------
 1 file changed, 66 deletions(-)
 delete mode 100644 models/sparse.py

diff --git a/models/sparse.py b/models/sparse.py
deleted file mode 100644
index 07b3413..0000000
--- a/models/sparse.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-
-class PseudoSparseEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
-        super().__init__()
-
-        self.embedding_dim = embedding_dim
-        self.dtype = dtype
-        self.params = nn.ParameterList()
-
-        if dropout_p > 0.0:
-            self.dropout = nn.Dropout(p=dropout_p)
-        else:
-            self.dropout = nn.Identity()
-
-        self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
-
-    def forward(self, input_ids: torch.LongTensor):
-        input_ids = input_ids.to(self.mapping.device)
-        ids = self.mapping[input_ids]
-        mask = ~(ids == -1)
-
-        if torch.all(~mask):
-            embs = None
-        else:
-            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
-
-        return embs, mask
-
-    def resize(self, new_num_embeddings: int):
-        old_num_embeddings = self.mapping.shape[0]
-        n = min(old_num_embeddings, new_num_embeddings)
-
-        new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1
-        new_mapping[:n] = self.mapping[:n]
-
-        self.mapping = new_mapping
-
-    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
-        if len(input_ids.shape) != 0:
-            if tensor is not None:
-                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
-            else:
-                return [self.set(id) for id in input_ids]
-
-        if tensor is None:
-            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
-
-        if tensor.shape[-1] != self.embedding_dim:
-            raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
-
-        id = self.mapping[input_ids]
-
-        if id == -1:
-            id = len(self.params)
-            self.mapping[input_ids] = id
-            self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
-
-        self.params[id] = tensor
-
-    def unset(self, input_ids: torch.LongTensor):
-        self.mapping[input_ids] = -1
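
For reference, a minimal usage sketch of the PseudoSparseEmbedding class deleted above, reconstructed from the deleted code itself; the embedding_dim value, the token ids, and the pre-removal import path models.sparse are illustrative assumptions, not taken from the rest of the repository.

    import torch

    # Import path as it existed before this commit deleted the file.
    from models.sparse import PseudoSparseEmbedding

    emb = PseudoSparseEmbedding(embedding_dim=768)

    # Reserve mapping slots for a vocabulary of 10 token ids;
    # resize() initializes every new slot to -1, i.e. "no override".
    emb.resize(10)

    # Attach trainable vectors to token ids 3 and 7. set() recurses
    # over batched ids, allocating a params slot for each unmapped id.
    emb.set(torch.tensor([3, 7]), torch.randn(2, 768))

    # Look up a batch of ids: `embs` stacks only the overridden vectors,
    # `mask` marks which input positions had an override.
    embs, mask = emb(torch.tensor([1, 3, 7]))
    assert mask.tolist() == [False, True, True]
    assert embs.shape == (2, 768)

    # Drop the override for id 3; its mapping slot returns to -1.
    emb.unset(torch.tensor(3))

Per the subject line, this commit removes the pseudo-sparse lookup because textual-inversion (TI) embeddings are handled via LoRA from here on.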