-rw-r--r--  models/clip/embeddings.py |  41
-rw-r--r--  models/lora.py            |  77
-rw-r--r--  models/sparse.py          | 110
-rw-r--r--  train_lora.py             |  11
-rw-r--r--  train_ti.py               |  21
-rw-r--r--  training/functional.py    |   7
6 files changed, 177 insertions, 90 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index d02ccc3..8aaea8f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -10,23 +10,21 @@ from transformers import CLIPTextModel
 from transformers.models.clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
-from models.lora import LoraEmbedding
+from models.sparse import SparseEmbedding
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0):
         super().__init__(config)
 
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
-        self.token_embedding = LoraEmbedding(
+        self.token_embedding = SparseEmbedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
-            r,
-            lora_alpha,
-            lora_dropout,
+            alpha,
+            dropout,
         )
-
         self.token_embedding.weight = embeddings.token_embedding.weight
 
     def resize(self, size: int):
@@ -82,38 +80,17 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         return self.token_embedding(input_ids)
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            inputs_embeds = self.get_embed(input_ids)
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
-
 
 def patch_managed_embeddings(
     text_encoder: CLIPTextModel,
-    r: int = 8,
-    lora_alpha: int = 8,
-    lora_dropout: float = 0.0
+    alpha: int = 8,
+    dropout: float = 0.0
 ) -> ManagedCLIPTextEmbeddings:
     text_embeddings = ManagedCLIPTextEmbeddings(
         text_encoder.config,
         text_encoder.text_model.embeddings,
-        r,
-        lora_alpha,
-        lora_dropout
+        alpha,
+        dropout
     )
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
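
Call-site effect of the embeddings.py change, as a minimal sketch (not part of the commit; the checkpoint name and the alpha/dropout values are illustrative assumptions):

```python
# Hypothetical usage of the patched embeddings after this commit:
# patch_managed_embeddings() now takes an embedding alpha and a dropout
# probability instead of LoRA rank/alpha/dropout.
from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

# Example checkpoint; any CLIP text encoder loadable this way would do.
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)
embeddings = patch_managed_embeddings(text_encoder, alpha=1.0, dropout=0.1)
assert text_encoder.text_model.embeddings is embeddings
```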
diff --git a/models/lora.py b/models/lora.py
index 01a540b..e506cff 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -1,8 +1,8 @@
 from typing import Optional
+import math
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class LoraLayer():
@@ -42,14 +42,12 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
             self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
         )
 
-        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long))
-        self.trainable_ids -= 1
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
 
-        if r > 0:
-            self.lora_A = nn.ParameterList()
-            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
-            self.scaling = self.lora_alpha / self.r
-            self.weight.requires_grad = False
+        self.lora_A = nn.ParameterList()
+        self.lora_B = nn.Linear(r, embedding_dim, bias=False)
+        self.scaling = self.lora_alpha / self.r
+        self.weight.requires_grad = False
 
         self.reset_parameters()
 
@@ -70,8 +68,9 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         else:
             nn.init.zeros_(new_emb.weight.data)
         new_emb.weight.data[:n, :] = self.weight.data[:n, :]
-        new_emb.lora_A = self.lora_A
-        new_emb.lora_B = self.lora_B
+        for param in self.lora_A:
+            new_emb.lora_A.append(param)
+        new_emb.lora_B.weight[:].data = self.lora_B.weight[:].data
         new_emb.trainable_ids[:n] = self.trainable_ids[:n]
 
         return new_emb
@@ -87,60 +86,60 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
         for _ in new_ids:
-            self.lora_A.append(nn.Parameter(self.weight.new_zeros(self.r)))
+            w = self.weight.new_zeros(self.r)
+            self.lora_A.append(w)
+
+        if len(self.lora_A) > 1:
+            elems = torch.stack([param for param in self.lora_A])
+            nn.init.kaiming_uniform_(elems, a=math.sqrt(5))
 
     def get_weights(self, input_ids: torch.Tensor):
         if len(input_ids.shape) != 1:
             return torch.stack([self.get_weights(batch) for batch in input_ids])
 
-        trainable_ids = self.trainable_ids[input_ids]
-        mask = ~(trainable_ids == -1)
-        trainable_ids = trainable_ids[mask]
-
         weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
-        elems = [self.lora_A[id] for id in trainable_ids]
 
-        if len(elems) != 0:
-            w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-            weights[mask] = w.to(dtype=weights.dtype)
+        if not self.merged:
+            trainable_ids = self.trainable_ids[input_ids]
+            mask = ~(trainable_ids == -1)
+            elems = [self.lora_A[id] for id in trainable_ids[mask]]
+
+            if len(elems) != 0:
+                w = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+                weights[mask] = w.to(dtype=weights.dtype)
 
         return weights
 
     def persist(self):
-        if self.r > 0:
-            weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-            self.weight.data += weights
-            self.trainable_ids[:] = -1
-            self.lora_A = nn.ParameterList()
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.lora_A = nn.ParameterList()
+        nn.init.zeros_(self.lora_B.weight)
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
-        if hasattr(self, 'lora_A'):
+        if hasattr(self, "lora_A"):
            self.trainable_ids[:] = -1
            self.lora_A = nn.ParameterList()
            nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
-        if self.merge_weights and self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data -= weights
+        self.lora_A.train(mode)
+        self.lora_B.train(mode)
+        if not mode and self.merge_weights and not self.merged:
+            self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            self.merged = True
+        elif self.merge_weights and self.merged:
+            self.weight.data -= self.get_weights(torch.arange(self.trainable_ids.shape[0]))
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
-        if self.merge_weights and not self.merged:
-            if self.r > 0:
-                weights = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-                self.weight.data += weights
-            self.merged = True
+        self.lora_A.eval()
+        self.lora_B.eval()
 
     def forward(self, input_ids: torch.LongTensor):
         result = nn.Embedding.forward(self, input_ids)
-
-        if self.r > 0 and not self.merged:
-            weights = self.get_weights(input_ids)
-            result += weights
-
+        result += self.get_weights(input_ids)
         return result
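
For orientation, a standalone conceptual sketch (plain PyTorch, not from the repository; sizes and the two-token count are made up) of how LoraEmbedding.get_weights() assembles the per-token delta that is added to the frozen base embedding:

```python
# Each trainable token id owns a rank-r vector (lora_A); a shared up-projection
# (lora_B) maps it to embedding_dim, and the result is scaled by lora_alpha / r.
import torch
import torch.nn as nn

r, embedding_dim, lora_alpha = 4, 16, 8
lora_A = [torch.zeros(r) for _ in range(2)]       # one rank-r vector per trainable token
lora_B = nn.Linear(r, embedding_dim, bias=False)  # shared up-projection
scaling = lora_alpha / r

delta = lora_B(torch.stack(lora_A)) * scaling     # shape: (2, embedding_dim)
print(delta.shape)                                # torch.Size([2, 16])
```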
diff --git a/models/sparse.py b/models/sparse.py
new file mode 100644
index 0000000..bd45696
--- /dev/null
+++ b/models/sparse.py
@@ -0,0 +1,110 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+class SparseEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        alpha: int = 1,
+        dropout: float = 0.0,
+        **kwargs
+    ):
+        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+
+        self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1)
+
+        self.trainable = nn.ParameterList()
+        self.scaling = alpha
+        self.dropout_p = dropout
+        self.weight.requires_grad = False
+
+        if dropout > 0.:
+            self.dropout = nn.Dropout(p=dropout)
+        else:
+            self.dropout = nn.Identity()
+
+        self.reset_parameters()
+
+    def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
+        n = min(self.num_embeddings, new_num_embeddings)
+
+        new_emb = SparseEmbedding(
+            new_num_embeddings,
+            self.embedding_dim,
+            self.scaling,
+            self.dropout_p,
+            device=self.weight.device,
+            dtype=self.weight.dtype
+        )
+        if initializer_factor is not None:
+            new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
+        else:
+            nn.init.zeros_(new_emb.weight.data)
+        new_emb.weight.data[:n, :] = self.weight.data[:n, :]
+        for param in self.trainable:
+            new_emb.trainable.append(param)
+        new_emb.trainable_ids[:n] = self.trainable_ids[:n]
+
+        return new_emb
+
+    def mark_trainable(self, input_ids: torch.LongTensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        new_ids = input_ids[trainable_ids == -1]
+
+        if new_ids.shape[0] == 0:
+            return
+
+        n1 = len(self.trainable)
+        n2 = n1 + new_ids.shape[0]
+        self.trainable_ids[new_ids] = torch.arange(n1, n2)
+        for _ in new_ids:
+            self.trainable.append(self.weight.new_zeros(self.embedding_dim))
+
+    def get_weights(self, input_ids: torch.Tensor):
+        original_shape = input_ids.shape
+
+        if len(input_ids.shape) != 1:
+            input_ids = input_ids.view(input_ids.shape[0] * input_ids.shape[1])
+
+        weights = self.weight.new_zeros((input_ids.shape[0], self.embedding_dim))
+
+        trainable_ids = self.trainable_ids[input_ids]
+        mask = ~(trainable_ids == -1)
+        elems = [self.trainable[id] for id in trainable_ids[mask]]
+
+        if len(elems) != 0:
+            w = self.dropout(torch.stack(elems)) * self.scaling
+            weights[mask] = w.to(dtype=weights.dtype)
+
+        if len(original_shape) != 1:
+            weights = weights.view(original_shape[0], original_shape[1], -1)
+
+        return weights
+
+    def persist(self):
+        self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+        self.trainable_ids[:] = -1
+        self.trainable = nn.ParameterList()
+
+    def reset_parameters(self):
+        nn.Embedding.reset_parameters(self)
+        if hasattr(self, "trainable"):
+            self.trainable_ids[:] = -1
+            self.trainable = nn.ParameterList()
+
+    def train(self, mode: bool = True):
+        nn.Embedding.train(self, mode)
+        self.trainable.train(mode)
+
+    def eval(self):
+        nn.Embedding.eval(self)
+        self.trainable.eval()
+
+    def forward(self, input_ids: torch.LongTensor):
+        result = nn.Embedding.forward(self, input_ids)
+        result += self.get_weights(input_ids)
+        return result
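
A short usage sketch of the new SparseEmbedding (not part of the commit; the vocabulary size, embedding size, and token ids are made up). Unlike the LoRA variant, each trainable token gets a full embedding_dim delta scaled by alpha, with no low-rank factorization:

```python
import torch

from models.sparse import SparseEmbedding

emb = SparseEmbedding(1000, 64, alpha=1, dropout=0.0)
emb.mark_trainable(torch.tensor([42, 99]))   # allocate delta rows for two token ids

out = emb(torch.tensor([[1, 42, 99]]))       # frozen base rows + scaled trainable deltas
print(out.shape)                             # torch.Size([1, 3, 64])

emb.persist()                                # fold learned deltas into the base weight and reset
```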
diff --git a/train_lora.py b/train_lora.py
index d5dde02..5c78664 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -507,6 +507,12 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--emb_alpha",
+        type=float,
+        default=1.0,
+        help="Embedding alpha"
+    )
+    parser.add_argument(
         "--emb_dropout",
         type=float,
         default=0,
@@ -660,7 +666,10 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_dropout)
+        args.pretrained_model_name_or_path,
+        args.emb_alpha,
+        args.emb_dropout
+    )
 
     unet_config = LoraConfig(
         r=args.lora_r,
diff --git a/train_ti.py b/train_ti.py
index 7f5fb49..45e730a 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -484,19 +484,13 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--lora_r",
-        type=int,
-        default=8,
-        help="Lora rank, only used if use_lora is True"
-    )
-    parser.add_argument(
-        "--lora_alpha",
-        type=int,
-        default=32,
-        help="Lora alpha, only used if use_lora is True"
+        "--emb_alpha",
+        type=float,
+        default=1.0,
+        help="Embedding alpha"
     )
     parser.add_argument(
-        "--lora_dropout",
+        "--emb_dropout",
         type=float,
         default=0,
         help="Embedding dropout probability.",
@@ -669,9 +663,8 @@ def main():
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path,
-        args.lora_r,
-        args.lora_alpha,
-        args.lora_dropout
+        args.emb_alpha,
+        args.emb_dropout
     )
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
diff --git a/training/functional.py b/training/functional.py
index 1fdfdc8..2da0f69 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -68,9 +68,8 @@ class TrainingStrategy():
 
 def get_models(
     pretrained_model_name_or_path: str,
-    emb_r: int = 8,
-    emb_lora_alpha: int = 8,
-    emb_lora_dropout: float = 0.0
+    emb_alpha: int = 8,
+    emb_dropout: float = 0.0
 ):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -80,7 +79,7 @@ def get_models(
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout)
+    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
