| field | value | date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
| commit | b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 (patch) | |
| tree | 24fd6d9f3a92ce9f5cccd5cdd914edfee665af71 /training | |
| parent | Fix (diff) | |
Fixed accuracy calc, other improvements
Diffstat (limited to 'training')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | training/ti.py | 48 |

1 file changed, 0 insertions, 48 deletions
diff --git a/training/ti.py b/training/ti.py
deleted file mode 100644
index 031fe48..0000000
--- a/training/ti.py
+++ /dev/null
@@ -1,48 +0,0 @@
from typing import Optional

import torch
import torch.nn as nn

from transformers.models.clip import CLIPTextModel, CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPTextEmbeddings


def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
    text_encoder.text_model.embeddings = text_embeddings


class TrainableEmbeddings(CLIPTextEmbeddings):
    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
        super().__init__(config)

        self.token_embedding = embeddings.token_embedding
        self.position_embedding = embeddings.position_embedding

        self.train_indices = torch.tensor(new_ids)

        self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
        self.trainable_embedding.weight.data.zero_()
        self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        device = input_ids.device
        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            mask = torch.isin(input_ids, self.train_indices.to(device))
            inputs_embeds = self.token_embedding(input_ids)
            inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask]

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
