author     Volpeon <git@volpeon.ink>  2023-04-15 13:11:11 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-15 13:11:11 +0200
commit     99b4dba56e3e1e434820d1221d561e90f1a6d30a (patch)
tree       717a4099e9ebfedec702060fed5ed12aaceb0094 /models
parent     Added cycle LR decay (diff)
download   textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.gz
           textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.bz2
           textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.zip
TI via LoRA
Diffstat (limited to 'models')
-rw-r--r--  models/clip/embeddings.py |  76
-rw-r--r--  models/lora.py            | 131
-rw-r--r--  models/sparse.py          |  66
3 files changed, 157 insertions, 116 deletions
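
In short: instead of keeping a separate override vector per new textual-inversion token (models/sparse.py), the token embedding itself becomes a LoRA module. Every token id registered as trainable gets a column of lora_A, and its embedding is the frozen weight row plus a scaled low-rank correction lora_B @ lora_A[:, column]. The helper below is a reference sketch of that computation, distilled from LoraEmbedding.forward in the diff; it is not part of the commit:

import torch

def lora_token_embedding(weight, lora_A, lora_B, trainable_ids, scaling, token_id):
    # What LoraEmbedding.forward returns for one token id while the delta is unmerged.
    e = weight[token_id]
    column = trainable_ids[token_id]
    if column >= 0:  # the token was registered via mark_trainable()
        e = e + scaling * (lora_B @ lora_A[:, column])
    return e
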
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9be8256..60c1b20 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,49 +11,27 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.sparse import PseudoSparseEmbedding
-
-
-def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
-    old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
-
-    if old_num_embeddings == new_num_embeddings:
-        return old_embedding
-
-    n = min(old_num_embeddings, new_num_embeddings)
-
-    new_embedding = nn.Embedding(
-        new_num_embeddings,
-        old_embedding_dim,
-        device=old_embedding.weight.device,
-        dtype=old_embedding.weight.dtype
-    )
-    if initializer_factor is not None:
-        new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
-    else:
-        nn.init.zeros_(new_embedding.weight.data)
-    new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
-    return new_embedding
+from models.lora import LoraEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
         super().__init__(config)
 
-        self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-
-        self.token_override_embedding = PseudoSparseEmbedding(
+        self.token_embedding = LoraEmbedding(
+            self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
-            dropout_p=dropout_p,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype,
+            r,
+            lora_alpha,
+            lora_dropout,
         )
 
+        self.token_embedding.weight = embeddings.token_embedding.weight
+
     def resize(self, size: int):
-        self.token_override_embedding.resize(size)
-        self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
+        self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
 
     def add_embed(
         self,
@@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.token_embedding.weight.data[token_ids] = initializer
-        self.token_override_embedding.set(token_ids, initializer)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(
-            self.token_embedding.num_embeddings,
-            device=self.token_override_embedding.mapping.device
-        )
-        embs, mask = self.token_override_embedding(input_ids)
-        if embs is not None:
-            input_ids = input_ids[mask]
-            self.token_embedding.weight.data[input_ids] = embs
-            self.token_override_embedding.unset(input_ids)
+        self.token_embedding.eval()
+        self.token_embedding.merged = False
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        embs = self.token_embedding(input_ids)
-        embs_override, mask = self.token_override_embedding(input_ids)
-        if embs_override is not None:
-            embs[mask] = embs_override
-
-        return embs
+        return self.token_embedding(input_ids)
 
     def forward(
         self,
@@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
+def patch_managed_embeddings(
+    text_encoder: CLIPTextModel,
+    r: int = 8,
+    lora_alpha: int = 8,
+    lora_dropout: float = 0.0
+) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(
+        text_encoder.config,
+        text_encoder.text_model.embeddings,
+        r,
+        lora_alpha,
+        lora_dropout
+    )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
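
For orientation, a minimal sketch of how the patched embeddings could be wired up after this change. The checkpoint name, the two placeholder tokens, and the optimizer are illustrative assumptions; patch_managed_embeddings, resize, mark_trainable, and get_embed are the pieces introduced or kept by the diff:

import torch
from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")  # any CLIP text encoder
embeddings = patch_managed_embeddings(text_encoder, r=8, lora_alpha=8, lora_dropout=0.0)

# Grow the vocabulary by two placeholder tokens and register them with the LoRA layer.
vocab_size = text_encoder.config.vocab_size
embeddings.resize(vocab_size + 2)
new_ids = torch.tensor([vocab_size, vocab_size + 1], dtype=torch.long)
embeddings.token_embedding.mark_trainable(new_ids)

# Build the optimizer after mark_trainable, since mark_trainable reallocates lora_A.
# Only the low-rank factors are handed to the optimizer in this sketch.
optimizer = torch.optim.AdamW(
    [embeddings.token_embedding.lora_A, embeddings.token_embedding.lora_B], lr=1e-4
)

vecs = embeddings.get_embed(new_ids)  # frozen rows plus the low-rank delta for the new tokens
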
diff --git a/models/lora.py b/models/lora.py
new file mode 100644
index 0000000..c0f74a6
--- /dev/null
+++ b/models/lora.py
@@ -0,0 +1,131 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LoraLayer():
+    def __init__(
+        self,
+        r: int,
+        lora_alpha: int,
+        lora_dropout: float,
+        merge_weights: bool,
+    ):
+        self.r = r
+        self.lora_alpha = lora_alpha
+        self.lora_dropout_p = lora_dropout
+
+        if lora_dropout > 0.:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = nn.Identity()
+
+        self.merged = False
+        self.merge_weights = merge_weights
+
+
+class LoraEmbedding(nn.Embedding, LoraLayer):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        merge_weights: bool = True,
+        **kwargs
+    ):
+        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+        LoraLayer.__init__(
+            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
+        )
+
+        self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long))
+        self.trainable_ids -= 1
+
+        if r > 0:
+            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
+            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+            self.scaling = self.lora_alpha / self.r
+            self.weight.requires_grad = False
+
+        self.reset_parameters()
+
+    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
+        n = min(self.num_embeddings, new_num_embeddings)
+
+        new_emb = LoraEmbedding(
+            new_num_embeddings,
+            self.embedding_dim,
+            self.r,
+            self.lora_alpha,
+            self.lora_dropout_p,
+            device=self.weight.device,
+            dtype=self.weight.dtype
+        )
+        if initializer_factor is not None:
+            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+        else:
+            nn.init.zeros_(new_emb.weight.data)
+        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
+        new_emb.lora_A = self.lora_A
+        new_emb.lora_B = self.lora_B
+        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
+
+        return new_emb
+
+    def mark_trainable(self, input_ids):
+        trainable_ids = self.trainable_ids[input_ids]
+        new_ids = input_ids[trainable_ids == -1]
+
+        if new_ids.shape[0] == 0:
+            return
+
+        n = self.lora_A.shape[1]
+        self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0], device=self.trainable_ids.device)
+
+        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n + new_ids.shape[0])))
+        lora_A.data[:, :n] = self.lora_A.data
+        self.lora_A = lora_A
+
+    def reset_parameters(self):
+        nn.Embedding.reset_parameters(self)
+        if hasattr(self, 'lora_A'):
+            nn.init.zeros_(self.lora_A)
+            nn.init.normal_(self.lora_B)
+
+    def train(self, mode: bool = True):
+        nn.Embedding.train(self, mode)
+        if self.merge_weights and self.merged:
+            if self.r > 0:
+                mask = ~(self.trainable_ids == -1)
+                trainable_ids = self.trainable_ids[mask]
+                self.weight.data[mask] -= (self.lora_B @ self.lora_A[:, trainable_ids]).T * self.scaling
+            self.merged = False
+
+    def eval(self):
+        nn.Embedding.eval(self)
+        if self.merge_weights and not self.merged:
+            if self.r > 0:
+                mask = ~(self.trainable_ids == -1)
+                trainable_ids = self.trainable_ids[mask]
+                self.weight.data[mask] += (self.lora_B @ self.lora_A[:, trainable_ids]).T * self.scaling
+            self.merged = True
+
+    def forward(self, input_ids: torch.Tensor):
+        result = nn.Embedding.forward(self, input_ids)
+
+        if self.r > 0 and not self.merged:
+            trainable_ids = self.trainable_ids[input_ids]
+            mask = ~(trainable_ids == -1)
+            trainable_ids = trainable_ids[mask]
+
+            after_A = F.embedding(
+                trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm,
+                self.norm_type, self.scale_grad_by_freq, self.sparse
+            )
+            result[mask] += (after_A @ self.lora_B.T) * self.scaling
+
+        return result
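
A small self-contained check of LoraEmbedding on its own, with toy sizes that are not taken from the repo: tokens never passed to mark_trainable keep their plain embedding rows, and merging the delta via eval() reproduces the unmerged forward pass.

import torch
from models.lora import LoraEmbedding

emb = LoraEmbedding(10, 4, r=2, lora_alpha=2)               # scaling = lora_alpha / r = 1
emb.mark_trainable(torch.tensor([3, 7], dtype=torch.long))  # tokens 3 and 7 get LoRA columns
torch.nn.init.normal_(emb.lora_A)                           # make the delta non-zero for the check

ids = torch.tensor([1, 3, 7])
out = emb(ids)

assert torch.equal(out[0], emb.weight[1])      # token 1 has no LoRA column: plain weight row
assert not torch.equal(out[1], emb.weight[3])  # token 3: weight row + lora_B @ lora_A[:, column]

emb.eval()                                     # merges the delta into emb.weight
assert torch.allclose(out, emb(ids), atol=1e-6)
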
diff --git a/models/sparse.py b/models/sparse.py
deleted file mode 100644
index 07b3413..0000000
--- a/models/sparse.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-
-class PseudoSparseEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
-        super().__init__()
-
-        self.embedding_dim = embedding_dim
-        self.dtype = dtype
-        self.params = nn.ParameterList()
-
-        if dropout_p > 0.0:
-            self.dropout = nn.Dropout(p=dropout_p)
-        else:
-            self.dropout = nn.Identity()
-
-        self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
-
-    def forward(self, input_ids: torch.LongTensor):
-        input_ids = input_ids.to(self.mapping.device)
-        ids = self.mapping[input_ids]
-        mask = ~(ids == -1)
-
-        if torch.all(~mask):
-            embs = None
-        else:
-            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
-
-        return embs, mask
-
-    def resize(self, new_num_embeddings: int):
-        old_num_embeddings = self.mapping.shape[0]
-        n = min(old_num_embeddings, new_num_embeddings)
-
-        new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1
-        new_mapping[:n] = self.mapping[:n]
-
-        self.mapping = new_mapping
-
-    def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
-        if len(input_ids.shape) != 0:
-            if tensor is not None:
-                return [self.set(id, t) for id, t in zip(input_ids, tensor)]
-            else:
-                return [self.set(id) for id in input_ids]
-
-        if tensor is None:
-            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
-
-        if tensor.shape[-1] != self.embedding_dim:
-            raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
-
-        id = self.mapping[input_ids]
-
-        if id == -1:
-            id = len(self.params)
-            self.mapping[input_ids] = id
-            self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
-
-        self.params[id] = tensor
-
-    def unset(self, input_ids: torch.LongTensor):
-        self.mapping[input_ids] = -1