-rw-r--r--  models/clip/embeddings.py   41
-rw-r--r--  models/lora.py              77
-rw-r--r--  models/sparse.py           110
-rw-r--r--  train_lora.py               11
-rw-r--r--  train_ti.py                 21
-rw-r--r--  training/functional.py       7
6 files changed, 177 insertions, 90 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d02ccc3..8aaea8f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -10,23 +10,21 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.lora import LoraEmbedding
+from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.token_embedding = LoraEmbedding(
+        self.token_embedding = SparseEmbedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
-            r,
-            lora_alpha,
-            lora_dropout,
+            alpha,
+            dropout,
         )
-
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
@@ -82,38 +80,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return self.token_embedding(input_ids)
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            inputs_embeds = self.get_embed(input_ids)
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
-
 
 def patch_managed_embeddings(
     text_encoder: CLIPTextModel,
-    r: int = 8,
-    lora_alpha: int = 8,
-    lora_dropout: float = 0.0
+    alpha: int = 8,
+    dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
-        r,
-        lora_alpha,
-        lora_dropout
+        alpha,
+        dropout
     )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
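
Not part of the diff: a minimal sketch, assuming the updated patch_managed_embeddings signature above, of how the patched text encoder would now be set up. The checkpoint name and the alpha/dropout values are placeholders.

from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

# Placeholder checkpoint; any model loadable as a CLIPTextModel will do.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The token embedding is replaced by a SparseEmbedding; the former r / lora_alpha /
# lora_dropout triple collapses into one alpha scaling factor plus a dropout probability.
embeddings = patch_managed_embeddings(text_encoder, alpha=1.0, dropout=0.1)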
diff --git a/models/lora.py b/models/lora.py
index 01a540b..e506cff 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -1,8 +1,8 @@
 from typing import Optional
+import math
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class LoraLayer():
@@ -42,14 +42,12 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
             self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
         )
 
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long))
-        self.trainable_ids -= 1
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
 
-        if r > 0:
-            self.lora_A = nn.ParameterList()
-            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
-            self.scaling = self.lora_alpha / self.r
-            self.weight.requires_grad = False
+        self.lora_A = nn.ParameterList()
+        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
+        self.scaling = self.lora_alpha / self.r
+        self.weight.requires_grad = False
 
         self.reset_parameters()
 
@@ -70,8 +68,9 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         else:
             nn.init.zeros_(new_emb.weight.data)
         new_emb.weight.data[:n, :] = self.weight.data[:n, :]
-        new_emb.lora_A = self.lora_A
-        new_emb.lora_B = self.lora_B
+        for param in self.lora_A:
+            new_emb.lora_A.append(param)
+        new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data
         new_emb.trainable_ids[:n] = self.trainable_ids[:n]
 
         return new_emb
@@ -87,60 +86,60 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
         for _ in new_ids:
-            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))
+            w = self.weight.new_zeros(self.r)
+            self.lora_A.append(w)
+
+        if len(self.lora_A) > 1:
+            elems = torch.stack([param for param in self.lora_A])
+            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
 
     def get_weights(self, input_ids: torch.Tensor):
         if len(input_ids.shape) != 1:
             return torch.stack([self.get_weights(batch) for batch in input_ids])
 
-        trainable_ids = self.trainable_ids[input_ids]
-        mask = ~(trainable_ids == -1)
-        trainable_ids = trainable_ids[mask]
-
         weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
-        elems = [self.lora_A[id] for id in trainable_ids]
 
-        if len(elems) != 0:
-            w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-            weights[mask] = w.to(dtype=weights.dtype)
+        if not self.merged:
+            trainable_ids = self.trainable_ids[input_ids]
+            mask = ~(trainable_ids == -1)
+            elems = [self.lora_A[id] for id in trainable_ids[mask]]
+
+            if len(elems) != 0:
+                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+                weights[mask] = w.to(dtype=weights.dtype)
 
         return weights
 
     def persist(self):
-        if self.r > 0:
-            weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.weight.data += weights
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.lora_A = nn.ParameterList()
+        nn.init.zeros_(self.lora_B.weight)
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
+        if hasattr(self, "lora_A"):
            self.trainable_ids[:] = -1
            self.lora_A = nn.ParameterList()
            nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
-        if self.merge_weights and self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data -= weights
+        self.lora_A.train(mode)
+        self.lora_B.train(mode)
+        if not mode and self.merge_weights and not self.merged:
+            self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            self.merged = True
+        elif self.merge_weights and self.merged:
+            self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0]))
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
-        if self.merge_weights and not self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data += weights
-            self.merged = True
+        self.lora_A.eval()
+        self.lora_B.eval()
 
     def forward(self, input_ids: torch.LongTensor):
         result = nn.Embedding.forward(self, input_ids)
-
-        if self.r > 0 and not self.merged:
-            weights = self.get_weights(input_ids)
-            result += weights
-
+        result += self.get_weights(input_ids)
         return result
diff --git a/models/sparse.py b/models/sparse.py
new file mode 100644
index 0000000..bd45696
--- /dev/null
+++ b/models/sparse.py
@@ -0,0 +1,110 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+class SparseEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        alpha: int = 1,
+        dropout: float = 0.0,
+        **kwargs
+    ):
+        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
+
+        self.trainable = nn.ParameterList()
+        self.scaling = alpha
+        self.dropout_p = dropout
+        self.weight.requires_grad = False
+
+        if dropout > 0.:
+            self.dropout = nn.Dropout(p=dropout)
+        else:
+            self.dropout = nn.Identity()
+
+        self.reset_parameters()
+
+    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
+        n = min(self.num_embeddings, new_num_embeddings)
+
+        new_emb = SparseEmbedding(
+            new_num_embeddings,
+            self.embedding_dim,
+            self.scaling,
+            self.dropout_p,
+            device=self.weight.device,
+            dtype=self.weight.dtype
+        )
+        if initializer_factor is not None:
+            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+        else:
+            nn.init.zeros_(new_emb.weight.data)
+        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
+        for param in self.trainable:
+            new_emb.trainable.append(param)
+        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
+
+        return new_emb
+
+    def mark_trainable(self, input_ids: torch.LongTensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        new_ids = input_ids[trainable_ids == -1]
+
+        if new_ids.shape[0] == 0:
+            return
+
+        n1 = len(self.trainable)
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)
+        for _ in new_ids:
+            self.trainable.append(self.weight.new_zeros(self.embedding_dim))
+
+    def get_weights(self, input_ids: torch.Tensor):
+        original_shape = input_ids.shape
+
+        if len(input_ids.shape) != 1:
+            input_ids = input_ids.view(input_ids.shape[0] * input_ids.shape[1])
+
+        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
+
+        trainable_ids = self.trainable_ids[input_ids]
+        mask = ~(trainable_ids == -1)
+        elems = [self.trainable[id] for id in trainable_ids[mask]]
+
+        if len(elems) != 0:
+            w = self.dropout(torch.stack(elems)) * self.scaling
+            weights[mask] = w.to(dtype=weights.dtype)
+
+        if len(original_shape) != 1:
+            weights = weights.view(original_shape[0], original_shape[1], -1)
+
+        return weights
+
+    def persist(self):
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.trainable = nn.ParameterList()
+
+    def reset_parameters(self):
+        nn.Embedding.reset_parameters(self)
+        if hasattr(self, "trainable"):
+            self.trainable_ids[:] = -1
+            self.trainable = nn.ParameterList()
+
+    def train(self, mode: bool = True):
+        nn.Embedding.train(self, mode)
+        self.trainable.train(mode)
+
+    def eval(self):
+        nn.Embedding.eval(self)
+        self.trainable.eval()
+
+    def forward(self, input_ids: torch.LongTensor):
+        result = nn.Embedding.forward(self, input_ids)
+        result += self.get_weights(input_ids)
+        return result
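
For illustration (a sketch, not repository code, assuming models/sparse.py exactly as added above): SparseEmbedding keeps the base weight frozen and learns per-token delta vectors only for ids that were explicitly marked trainable; persist() folds those deltas back into the weight.

import torch

from models.sparse import SparseEmbedding

emb = SparseEmbedding(1000, 32, alpha=1.0, dropout=0.0)

# Only these token ids get trainable delta vectors; every other row stays frozen.
emb.mark_trainable(torch.tensor([42, 99]))
print(len(emb.trainable))           # 2 deltas, each of shape (embedding_dim,)

ids = torch.tensor([[42, 7, 99]])   # (batch, seq) lookup
out = emb(ids)                      # frozen embedding + scaled deltas where marked

# Merge the learned deltas into the base weight and clear the trainable set.
emb.persist()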
diff --git a/train_lora.py b/train_lora.py
index d5dde02..5c78664 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -507,6 +507,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_alpha",
+        type=float,
+        default=1.0,
+        help="Embedding alpha"
+    )
+    parser.add_argument(
         "--emb_dropout",
         type=float,
         default=0,
@@ -660,7 +666,10 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_dropout)
+        args.pretrained_model_name_or_path,
+        args.emb_alpha,
+        args.emb_dropout
+    )
 
     unet_config = LoraConfig(
         r=args.lora_r,
diff --git a/train_ti.py b/train_ti.py
index 7f5fb49..45e730a 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -484,19 +484,13 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--lora_r",
-        type=int,
-        default=8,
-        help="Lora rank, only used if use_lora is True"
-    )
-    parser.add_argument(
-        "--lora_alpha",
-        type=int,
-        default=32,
-        help="Lora alpha, only used if use_lora is True"
+        "--emb_alpha",
+        type=float,
+        default=1.0,
+        help="Embedding alpha"
     )
     parser.add_argument(
-        "--lora_dropout",
+        "--emb_dropout",
         type=float,
         default=0,
         help="Embedding dropout probability.",
@@ -669,9 +663,8 @@ def main():
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path,
-        args.lora_r,
-        args.lora_alpha,
-        args.lora_dropout
+        args.emb_alpha,
+        args.emb_dropout
     )
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
diff --git a/training/functional.py b/training/functional.py
index 1fdfdc8..2da0f69 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -68,9 +68,8 @@ class TrainingStrategy():
 
 def get_models(
     pretrained_model_name_or_path: str,
-    emb_r: int = 8,
-    emb_lora_alpha: int = 8,
-    emb_lora_dropout: float = 0.0
+    emb_alpha: int = 8,
+    emb_dropout: float = 0.0
 ):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -80,7 +79,7 @@ def get_models(
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout)
+    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
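
End-to-end, the training scripts now thread a single embedding alpha and dropout through get_models instead of the old emb_r / emb_lora_alpha / emb_lora_dropout triple. A hedged sketch of the new call, with a placeholder checkpoint path and the new argparse defaults:

from training.functional import get_models

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
    emb_alpha=1.0,                     # default of the new --emb_alpha flag
    emb_dropout=0.0,                   # default of --emb_dropout
)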