-rw-r--r-- | models/clip/embeddings.py | 76
-rw-r--r-- | models/lora.py | 131
-rw-r--r-- | models/sparse.py | 66
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 20
-rw-r--r-- | train_lora.py | 7
-rw-r--r-- | train_ti.py | 28
-rw-r--r-- | training/functional.py | 11
-rw-r--r-- | training/strategy/lora.py | 4
-rw-r--r-- | training/strategy/ti.py | 9
9 files changed, 212 insertions(+), 140 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9be8256..60c1b20 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
| @@ -11,49 +11,27 @@ from transformers import CLIPTextModel | |||
| 11 | from transformers.models.clip import CLIPTextConfig | 11 | from transformers.models.clip import CLIPTextConfig |
| 12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | 12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings |
| 13 | 13 | ||
| 14 | from models.sparse import PseudoSparseEmbedding | 14 | from models.lora import LoraEmbedding |
| 15 | |||
| 16 | |||
| 17 | def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding: | ||
| 18 | old_num_embeddings, old_embedding_dim = old_embedding.weight.shape | ||
| 19 | |||
| 20 | if old_num_embeddings == new_num_embeddings: | ||
| 21 | return old_embedding | ||
| 22 | |||
| 23 | n = min(old_num_embeddings, new_num_embeddings) | ||
| 24 | |||
| 25 | new_embedding = nn.Embedding( | ||
| 26 | new_num_embeddings, | ||
| 27 | old_embedding_dim, | ||
| 28 | device=old_embedding.weight.device, | ||
| 29 | dtype=old_embedding.weight.dtype | ||
| 30 | ) | ||
| 31 | if initializer_factor is not None: | ||
| 32 | new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | ||
| 33 | else: | ||
| 34 | nn.init.zeros_(new_embedding.weight.data) | ||
| 35 | new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] | ||
| 36 | return new_embedding | ||
| 37 | 15 | ||
| 38 | 16 | ||
| 39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 17 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
| 40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0): | 18 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0): |
| 41 | super().__init__(config) | 19 | super().__init__(config) |
| 42 | 20 | ||
| 43 | self.token_embedding = embeddings.token_embedding | ||
| 44 | self.position_embedding = embeddings.position_embedding | 21 | self.position_embedding = embeddings.position_embedding |
| 45 | self.initializer_factor = config.initializer_factor | 22 | self.initializer_factor = config.initializer_factor |
| 46 | 23 | self.token_embedding = LoraEmbedding( | |
| 47 | self.token_override_embedding = PseudoSparseEmbedding( | 24 | self.token_embedding.num_embeddings, |
| 48 | self.token_embedding.embedding_dim, | 25 | self.token_embedding.embedding_dim, |
| 49 | dropout_p=dropout_p, | 26 | r, |
| 50 | device=self.token_embedding.weight.device, | 27 | lora_alpha, |
| 51 | dtype=self.token_embedding.weight.dtype, | 28 | lora_dropout, |
| 52 | ) | 29 | ) |
| 53 | 30 | ||
| 31 | self.token_embedding.weight = embeddings.token_embedding.weight | ||
| 32 | |||
| 54 | def resize(self, size: int): | 33 | def resize(self, size: int): |
| 55 | self.token_override_embedding.resize(size) | 34 | self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor) |
| 56 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | ||
| 57 | 35 | ||
| 58 | def add_embed( | 36 | def add_embed( |
| 59 | self, | 37 | self, |
| @@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 87 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 65 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 88 | 66 | ||
| 89 | self.token_embedding.weight.data[token_ids] = initializer | 67 | self.token_embedding.weight.data[token_ids] = initializer |
| 90 | self.token_override_embedding.set(token_ids, initializer) | ||
| 91 | 68 | ||
| 92 | def load_embed(self, input_ids: list[int], filename: Path): | 69 | def load_embed(self, input_ids: list[int], filename: Path): |
| 93 | with safe_open(filename, framework="pt", device="cpu") as file: | 70 | with safe_open(filename, framework="pt", device="cpu") as file: |
| @@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 97 | save_file({"embed": self.get_embed(input_ids)}, filename) | 74 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 98 | 75 | ||
| 99 | def persist(self): | 76 | def persist(self): |
| 100 | input_ids = torch.arange( | 77 | self.token_embedding.eval() |
| 101 | self.token_embedding.num_embeddings, | 78 | self.token_embedding.merged = False |
| 102 | device=self.token_override_embedding.mapping.device | ||
| 103 | ) | ||
| 104 | embs, mask = self.token_override_embedding(input_ids) | ||
| 105 | if embs is not None: | ||
| 106 | input_ids = input_ids[mask] | ||
| 107 | self.token_embedding.weight.data[input_ids] = embs | ||
| 108 | self.token_override_embedding.unset(input_ids) | ||
| 109 | 79 | ||
| 110 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 80 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 111 | if isinstance(input_ids, list): | 81 | if isinstance(input_ids, list): |
| 112 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 82 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
| 113 | 83 | ||
| 114 | embs = self.token_embedding(input_ids) | 84 | return self.token_embedding(input_ids) |
| 115 | embs_override, mask = self.token_override_embedding(input_ids) | ||
| 116 | if embs_override is not None: | ||
| 117 | embs[mask] = embs_override | ||
| 118 | |||
| 119 | return embs | ||
| 120 | 85 | ||
| 121 | def forward( | 86 | def forward( |
| 122 | self, | 87 | self, |
| @@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 138 | return embeddings | 103 | return embeddings |
| 139 | 104 | ||
| 140 | 105 | ||
| 141 | def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings: | 106 | def patch_managed_embeddings( |
| 142 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p) | 107 | text_encoder: CLIPTextModel, |
| 108 | r: int = 8, | ||
| 109 | lora_alpha: int = 8, | ||
| 110 | lora_dropout: float = 0.0 | ||
| 111 | ) -> ManagedCLIPTextEmbeddings: | ||
| 112 | text_embeddings = ManagedCLIPTextEmbeddings( | ||
| 113 | text_encoder.config, | ||
| 114 | text_encoder.text_model.embeddings, | ||
| 115 | r, | ||
| 116 | lora_alpha, | ||
| 117 | lora_dropout | ||
| 118 | ) | ||
| 143 | text_encoder.text_model.embeddings = text_embeddings | 119 | text_encoder.text_model.embeddings = text_embeddings |
| 144 | return text_embeddings | 120 | return text_embeddings |
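The net effect of the hunks above: ManagedCLIPTextEmbeddings no longer keeps a separate override table; token_embedding itself becomes a LoraEmbedding that shares the pretrained weight tensor and handles resizing via new_resized(). A minimal wiring sketch, assuming a standard CLIP text checkpoint (the model id below is a placeholder, not taken from this commit):

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    # placeholder checkpoint; any CLIPTextModel works the same way
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # swap the stock token embedding for the LoRA-backed managed embedding
    embeddings = patch_managed_embeddings(text_encoder, r=8, lora_alpha=8, lora_dropout=0.0)

    # growing the vocabulary now goes through LoraEmbedding.new_resized()
    # instead of the removed resize_embedding() helper
    embeddings.resize(text_encoder.config.vocab_size + 2)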
diff --git a/models/lora.py b/models/lora.py
new file mode 100644
index 0000000..c0f74a6
--- /dev/null
+++ b/models/lora.py
| @@ -0,0 +1,131 @@ | |||
| 1 | from typing import Optional | ||
| 2 | |||
| 3 | import torch | ||
| 4 | import torch.nn as nn | ||
| 5 | import torch.nn.functional as F | ||
| 6 | |||
| 7 | |||
| 8 | class LoraLayer(): | ||
| 9 | def __init__( | ||
| 10 | self, | ||
| 11 | r: int, | ||
| 12 | lora_alpha: int, | ||
| 13 | lora_dropout: float, | ||
| 14 | merge_weights: bool, | ||
| 15 | ): | ||
| 16 | self.r = r | ||
| 17 | self.lora_alpha = lora_alpha | ||
| 18 | self.lora_dropout_p = lora_dropout | ||
| 19 | |||
| 20 | if lora_dropout > 0.: | ||
| 21 | self.lora_dropout = nn.Dropout(p=lora_dropout) | ||
| 22 | else: | ||
| 23 | self.lora_dropout = nn.Identity() | ||
| 24 | |||
| 25 | self.merged = False | ||
| 26 | self.merge_weights = merge_weights | ||
| 27 | |||
| 28 | |||
| 29 | class LoraEmbedding(nn.Embedding, LoraLayer): | ||
| 30 | def __init__( | ||
| 31 | self, | ||
| 32 | num_embeddings: int, | ||
| 33 | embedding_dim: int, | ||
| 34 | r: int = 0, | ||
| 35 | lora_alpha: int = 1, | ||
| 36 | lora_dropout: float = 0.0, | ||
| 37 | merge_weights: bool = True, | ||
| 38 | **kwargs | ||
| 39 | ): | ||
| 40 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) | ||
| 41 | LoraLayer.__init__( | ||
| 42 | self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights | ||
| 43 | ) | ||
| 44 | |||
| 45 | self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long)) | ||
| 46 | self.trainable_ids -= 1 | ||
| 47 | |||
| 48 | if r > 0: | ||
| 49 | self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0))) | ||
| 50 | self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r))) | ||
| 51 | self.scaling = self.lora_alpha / self.r | ||
| 52 | self.weight.requires_grad = False | ||
| 53 | |||
| 54 | self.reset_parameters() | ||
| 55 | |||
| 56 | def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): | ||
| 57 | n = min(self.num_embeddings, new_num_embeddings) | ||
| 58 | |||
| 59 | new_emb = LoraEmbedding( | ||
| 60 | new_num_embeddings, | ||
| 61 | self.embedding_dim, | ||
| 62 | self.r, | ||
| 63 | self.lora_alpha, | ||
| 64 | self.lora_dropout_p, | ||
| 65 | device=self.weight.device, | ||
| 66 | dtype=self.weight.dtype | ||
| 67 | ) | ||
| 68 | if initializer_factor is not None: | ||
| 69 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | ||
| 70 | else: | ||
| 71 | nn.init.zeros_(new_emb.weight.data) | ||
| 72 | new_emb.weight.data[:n, :] = self.weight.data[:n, :] | ||
| 73 | new_emb.lora_A = self.lora_A | ||
| 74 | new_emb.lora_B = self.lora_B | ||
| 75 | new_emb.trainable_ids[:n] = self.trainable_ids[:n] | ||
| 76 | |||
| 77 | return new_emb | ||
| 78 | |||
| 79 | def mark_trainable(self, input_ids): | ||
| 80 | trainable_ids = self.trainable_ids[input_ids] | ||
| 81 | new_ids = trainable_ids[trainable_ids == -1] | ||
| 82 | |||
| 83 | if new_ids.shape[0] == 0: | ||
| 84 | return | ||
| 85 | |||
| 86 | n = self.trainable_ids.shape[0] | ||
| 87 | self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0]) | ||
| 88 | |||
| 89 | lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0))) | ||
| 90 | lora_A.data[:n] = self.lora_A.data | ||
| 91 | self.lora_A = lora_A | ||
| 92 | |||
| 93 | def reset_parameters(self): | ||
| 94 | nn.Embedding.reset_parameters(self) | ||
| 95 | if hasattr(self, 'lora_A'): | ||
| 96 | nn.init.zeros_(self.lora_A) | ||
| 97 | nn.init.normal_(self.lora_B) | ||
| 98 | |||
| 99 | def train(self, mode: bool = True): | ||
| 100 | nn.Embedding.train(self, mode) | ||
| 101 | if self.merge_weights and self.merged: | ||
| 102 | if self.r > 0: | ||
| 103 | mask = ~(self.trainable_ids == -1) | ||
| 104 | trainable_ids = self.trainable_ids[mask] | ||
| 105 | self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling | ||
| 106 | self.merged = False | ||
| 107 | |||
| 108 | def eval(self): | ||
| 109 | nn.Embedding.eval(self) | ||
| 110 | if self.merge_weights and not self.merged: | ||
| 111 | if self.r > 0: | ||
| 112 | mask = ~(self.trainable_ids == -1) | ||
| 113 | trainable_ids = self.trainable_ids[mask] | ||
| 114 | self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling | ||
| 115 | self.merged = True | ||
| 116 | |||
| 117 | def forward(self, input_ids: torch.Tensor): | ||
| 118 | result = nn.Embedding.forward(self, input_ids) | ||
| 119 | |||
| 120 | if self.r > 0 and not self.merged: | ||
| 121 | trainable_ids = self.trainable_ids[input_ids] | ||
| 122 | mask = ~(trainable_ids == -1) | ||
| 123 | trainable_ids = trainable_ids[mask] | ||
| 124 | |||
| 125 | after_A = F.embedding( | ||
| 126 | trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm, | ||
| 127 | self.norm_type, self.scale_grad_by_freq, self.sparse | ||
| 128 | ) | ||
| 129 | result[mask] += (after_A @ self.lora_B.T) * self.scaling | ||
| 130 | |||
| 131 | return result | ||
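The update rule the new module implements is the standard LoRA one, restricted to rows that have been registered via mark_trainable(): the frozen row weight[i] gains (lora_B @ lora_A[:, j]) * lora_alpha / r, where j is that token's column in lora_A, and eval() is meant to merge the delta into the base weight. A self-contained sketch of that arithmetic with illustrative sizes (independent of the class above):

    import torch

    num_embeddings, dim, r, alpha = 100, 16, 4, 4
    weight = torch.randn(num_embeddings, dim)   # frozen base embedding rows
    lora_A = torch.zeros(r, 2)                  # one column per trainable token, zero-initialised
    lora_B = torch.randn(dim, r)                # shared up-projection
    scaling = alpha / r

    trainable = {7: 0, 42: 1}                   # token id -> column index in lora_A

    ids = torch.tensor([7, 3, 42])
    out = weight[ids].clone()
    for pos, tok in enumerate(ids.tolist()):
        if tok in trainable:                    # untracked tokens keep the frozen row unchanged
            out[pos] += (lora_B @ lora_A[:, trainable[tok]]) * scaling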
diff --git a/models/sparse.py b/models/sparse.py
deleted file mode 100644
index 07b3413..0000000
--- a/models/sparse.py
+++ /dev/null
| @@ -1,66 +0,0 @@ | |||
| 1 | from typing import Optional | ||
| 2 | |||
| 3 | import torch | ||
| 4 | import torch.nn as nn | ||
| 5 | |||
| 6 | |||
| 7 | class PseudoSparseEmbedding(nn.Module): | ||
| 8 | def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32): | ||
| 9 | super().__init__() | ||
| 10 | |||
| 11 | self.embedding_dim = embedding_dim | ||
| 12 | self.dtype = dtype | ||
| 13 | self.params = nn.ParameterList() | ||
| 14 | |||
| 15 | if dropout_p > 0.0: | ||
| 16 | self.dropout = nn.Dropout(p=dropout_p) | ||
| 17 | else: | ||
| 18 | self.dropout = nn.Identity() | ||
| 19 | |||
| 20 | self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) | ||
| 21 | |||
| 22 | def forward(self, input_ids: torch.LongTensor): | ||
| 23 | input_ids = input_ids.to(self.mapping.device) | ||
| 24 | ids = self.mapping[input_ids] | ||
| 25 | mask = ~(ids == -1) | ||
| 26 | |||
| 27 | if torch.all(~mask): | ||
| 28 | embs = None | ||
| 29 | else: | ||
| 30 | embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]])) | ||
| 31 | |||
| 32 | return embs, mask | ||
| 33 | |||
| 34 | def resize(self, new_num_embeddings: int): | ||
| 35 | old_num_embeddings = self.mapping.shape[0] | ||
| 36 | n = min(old_num_embeddings, new_num_embeddings) | ||
| 37 | |||
| 38 | new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1 | ||
| 39 | new_mapping[:n] = self.mapping[:n] | ||
| 40 | |||
| 41 | self.mapping = new_mapping | ||
| 42 | |||
| 43 | def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None): | ||
| 44 | if len(input_ids.shape) != 0: | ||
| 45 | if tensor is not None: | ||
| 46 | return [self.set(id, t) for id, t in zip(input_ids, tensor)] | ||
| 47 | else: | ||
| 48 | return [self.set(id) for id in input_ids] | ||
| 49 | |||
| 50 | if tensor is None: | ||
| 51 | tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype) | ||
| 52 | |||
| 53 | if tensor.shape[-1] != self.embedding_dim: | ||
| 54 | raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]") | ||
| 55 | |||
| 56 | id = self.mapping[input_ids] | ||
| 57 | |||
| 58 | if id == -1: | ||
| 59 | id = len(self.params) | ||
| 60 | self.mapping[input_ids] = id | ||
| 61 | self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) | ||
| 62 | |||
| 63 | self.params[id] = tensor | ||
| 64 | |||
| 65 | def unset(self, input_ids: torch.LongTensor): | ||
| 66 | self.mapping[input_ids] = -1 | ||
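For contrast, the deleted PseudoSparseEmbedding replaced whole rows for the ids it tracked rather than adding a low-rank correction. Schematically (a simplification, not the removed API):

    import torch

    weight = torch.randn(100, 16)               # base embedding rows
    overrides = {7: torch.zeros(16)}            # per-token replacement vectors (the old ParameterList)

    ids = torch.tensor([7, 3])
    out = weight[ids].clone()
    for pos, tok in enumerate(ids.tolist()):
        if tok in overrides:
            out[pos] = overrides[tok]           # replace the row outright, instead of adding a delta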
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 13ea2ac..a0dff54 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
| @@ -591,15 +591,23 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 591 | if callback is not None and i % callback_steps == 0: | 591 | if callback is not None and i % callback_steps == 0: |
| 592 | callback(i, t, latents) | 592 | callback(i, t, latents) |
| 593 | 593 | ||
| 594 | # 9. Post-processing | ||
| 595 | image = self.decode_latents(latents) | ||
| 596 | |||
| 597 | # 10. Run safety checker | ||
| 598 | has_nsfw_concept = None | 594 | has_nsfw_concept = None |
| 599 | 595 | ||
| 600 | # 11. Convert to PIL | 596 | if output_type == "latent": |
| 601 | if output_type == "pil": | 597 | image = latents |
| 598 | elif output_type == "pil": | ||
| 599 | # 9. Post-processing | ||
| 600 | image = self.decode_latents(latents) | ||
| 601 | |||
| 602 | # 10. Convert to PIL | ||
| 602 | image = self.numpy_to_pil(image) | 603 | image = self.numpy_to_pil(image) |
| 604 | else: | ||
| 605 | # 9. Post-processing | ||
| 606 | image = self.decode_latents(latents) | ||
| 607 | |||
| 608 | # Offload last model to CPU | ||
| 609 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: | ||
| 610 | self.final_offload_hook.offload() | ||
| 603 | 611 | ||
| 604 | if not return_dict: | 612 | if not return_dict: |
| 605 | return (image, has_nsfw_concept) | 613 | return (image, has_nsfw_concept) |
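With this hunk the pipeline only runs the VAE decode when it is needed and can hand back raw latents via output_type="latent". A hedged call sketch; the tuple layout comes from the return statement above, while the positional prompt argument is assumed to follow the usual Stable Diffusion pipeline signature:

    def generate(pipe, prompt):
        # raw latents, no VAE decode; useful for chaining into a second pass
        latents, _ = pipe(prompt, output_type="latent", return_dict=False)
        # "pil" decodes and converts to PIL; any other output_type returns the decoded array
        images, _ = pipe(prompt, output_type="pil", return_dict=False)
        return latents, images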
diff --git a/train_lora.py b/train_lora.py
index b8c7396..91bda5c 100644
--- a/train_lora.py
+++ b/train_lora.py
| @@ -387,7 +387,7 @@ def parse_args(): | |||
| 387 | parser.add_argument( | 387 | parser.add_argument( |
| 388 | "--optimizer", | 388 | "--optimizer", |
| 389 | type=str, | 389 | type=str, |
| 390 | default="dadan", | 390 | default="adan", |
| 391 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 391 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
| 392 | help='Optimizer to use' | 392 | help='Optimizer to use' |
| 393 | ) | 393 | ) |
| @@ -412,7 +412,7 @@ def parse_args(): | |||
| 412 | parser.add_argument( | 412 | parser.add_argument( |
| 413 | "--adam_weight_decay", | 413 | "--adam_weight_decay", |
| 414 | type=float, | 414 | type=float, |
| 415 | default=1e-2, | 415 | default=2e-2, |
| 416 | help="Weight decay to use." | 416 | help="Weight decay to use." |
| 417 | ) | 417 | ) |
| 418 | parser.add_argument( | 418 | parser.add_argument( |
| @@ -780,6 +780,7 @@ def main(): | |||
| 780 | timm.optim.Adan, | 780 | timm.optim.Adan, |
| 781 | weight_decay=args.adam_weight_decay, | 781 | weight_decay=args.adam_weight_decay, |
| 782 | eps=args.adam_epsilon, | 782 | eps=args.adam_epsilon, |
| 783 | no_prox=True, | ||
| 783 | ) | 784 | ) |
| 784 | elif args.optimizer == 'lion': | 785 | elif args.optimizer == 'lion': |
| 785 | try: | 786 | try: |
| @@ -961,7 +962,7 @@ def main(): | |||
| 961 | 962 | ||
| 962 | if len(args.placeholder_tokens) != 0: | 963 | if len(args.placeholder_tokens) != 0: |
| 963 | params_to_optimize.append({ | 964 | params_to_optimize.append({ |
| 964 | "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 965 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), |
| 965 | "lr": learning_rate_emb, | 966 | "lr": learning_rate_emb, |
| 966 | "weight_decay": 0, | 967 | "weight_decay": 0, |
| 967 | }) | 968 | }) |
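Besides switching the default optimizer from dadan to adan and doubling the weight decay, the hunk passes no_prox=True to timm's Adan, selecting its alternative handling of decoupled weight decay. A sketch of the resulting factory, mirroring the functools.partial pattern the script uses (the eps value is illustrative; the script forwards args.adam_epsilon):

    from functools import partial

    import timm.optim

    create_optimizer = partial(
        timm.optim.Adan,
        weight_decay=2e-2,   # the new --adam_weight_decay default
        eps=1e-8,            # illustrative; the script passes args.adam_epsilon
        no_prox=True,        # non-proximal weight-decay variant of Adan
    )

    # later, per parameter group:
    # optimizer = create_optimizer(params_to_optimize, lr=learning_rate)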
diff --git a/train_ti.py b/train_ti.py
index d931db6..6c57f4b 100644
--- a/train_ti.py
+++ b/train_ti.py
| @@ -18,7 +18,6 @@ import transformers | |||
| 18 | 18 | ||
| 19 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
| 20 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
| 21 | from models.convnext.discriminator import ConvNeXtDiscriminator | ||
| 22 | from training.functional import train, add_placeholder_tokens, get_models | 21 | from training.functional import train, add_placeholder_tokens, get_models |
| 23 | from training.strategy.ti import textual_inversion_strategy | 22 | from training.strategy.ti import textual_inversion_strategy |
| 24 | from training.optimization import get_scheduler | 23 | from training.optimization import get_scheduler |
| @@ -354,7 +353,7 @@ def parse_args(): | |||
| 354 | parser.add_argument( | 353 | parser.add_argument( |
| 355 | "--optimizer", | 354 | "--optimizer", |
| 356 | type=str, | 355 | type=str, |
| 357 | default="dadan", | 356 | default="adan", |
| 358 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 357 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
| 359 | help='Optimizer to use' | 358 | help='Optimizer to use' |
| 360 | ) | 359 | ) |
| @@ -379,7 +378,7 @@ def parse_args(): | |||
| 379 | parser.add_argument( | 378 | parser.add_argument( |
| 380 | "--adam_weight_decay", | 379 | "--adam_weight_decay", |
| 381 | type=float, | 380 | type=float, |
| 382 | default=0, | 381 | default=2e-2, |
| 383 | help="Weight decay to use." | 382 | help="Weight decay to use." |
| 384 | ) | 383 | ) |
| 385 | parser.add_argument( | 384 | parser.add_argument( |
| @@ -483,7 +482,19 @@ def parse_args(): | |||
| 483 | help="The weight of prior preservation loss." | 482 | help="The weight of prior preservation loss." |
| 484 | ) | 483 | ) |
| 485 | parser.add_argument( | 484 | parser.add_argument( |
| 486 | "--emb_dropout", | 485 | "--lora_r", |
| 486 | type=int, | ||
| 487 | default=8, | ||
| 488 | help="Lora rank, only used if use_lora is True" | ||
| 489 | ) | ||
| 490 | parser.add_argument( | ||
| 491 | "--lora_alpha", | ||
| 492 | type=int, | ||
| 493 | default=32, | ||
| 494 | help="Lora alpha, only used if use_lora is True" | ||
| 495 | ) | ||
| 496 | parser.add_argument( | ||
| 497 | "--lora_dropout", | ||
| 487 | type=float, | 498 | type=float, |
| 488 | default=0, | 499 | default=0, |
| 489 | help="Embedding dropout probability.", | 500 | help="Embedding dropout probability.", |
| @@ -655,7 +666,11 @@ def main(): | |||
| 655 | save_args(output_dir, args) | 666 | save_args(output_dir, args) |
| 656 | 667 | ||
| 657 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 668 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| 658 | args.pretrained_model_name_or_path, args.emb_dropout) | 669 | args.pretrained_model_name_or_path, |
| 670 | args.lora_r, | ||
| 671 | args.lora_alpha, | ||
| 672 | args.lora_dropout | ||
| 673 | ) | ||
| 659 | 674 | ||
| 660 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 675 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
| 661 | tokenizer.set_dropout(args.vector_dropout) | 676 | tokenizer.set_dropout(args.vector_dropout) |
| @@ -747,6 +762,7 @@ def main(): | |||
| 747 | timm.optim.Adan, | 762 | timm.optim.Adan, |
| 748 | weight_decay=args.adam_weight_decay, | 763 | weight_decay=args.adam_weight_decay, |
| 749 | eps=args.adam_epsilon, | 764 | eps=args.adam_epsilon, |
| 765 | no_prox=True, | ||
| 750 | ) | 766 | ) |
| 751 | elif args.optimizer == 'lion': | 767 | elif args.optimizer == 'lion': |
| 752 | try: | 768 | try: |
| @@ -914,7 +930,7 @@ def main(): | |||
| 914 | print("") | 930 | print("") |
| 915 | 931 | ||
| 916 | optimizer = create_optimizer( | 932 | optimizer = create_optimizer( |
| 917 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 933 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
| 918 | lr=learning_rate, | 934 | lr=learning_rate, |
| 919 | ) | 935 | ) |
| 920 | 936 | ||
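With the LoRA-backed embedding, token_embedding.parameters() yields the base weight (frozen when r > 0) plus lora_A and lora_B, so the textual-inversion optimizer now trains the low-rank factors instead of a separate override table. A sketch of the changed call, wrapped as a helper so it stands alone (create_optimizer and learning_rate are the script's own objects):

    def build_emb_optimizer(text_encoder, create_optimizer, learning_rate):
        # token_embedding is now a LoraEmbedding: its parameters are the frozen base
        # weight (requires_grad=False when r > 0) plus lora_A and lora_B
        emb = text_encoder.text_model.embeddings.token_embedding
        return create_optimizer(emb.parameters(), lr=learning_rate)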
diff --git a/training/functional.py b/training/functional.py
index 54bbe78..1fdfdc8 100644
--- a/training/functional.py
+++ b/training/functional.py
| @@ -66,7 +66,12 @@ class TrainingStrategy(): | |||
| 66 | prepare: TrainingStrategyPrepareCallable | 66 | prepare: TrainingStrategyPrepareCallable |
| 67 | 67 | ||
| 68 | 68 | ||
| 69 | def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): | 69 | def get_models( |
| 70 | pretrained_model_name_or_path: str, | ||
| 71 | emb_r: int = 8, | ||
| 72 | emb_lora_alpha: int = 8, | ||
| 73 | emb_lora_dropout: float = 0.0 | ||
| 74 | ): | ||
| 70 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 75 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
| 71 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 76 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
| 72 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 77 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
| @@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): | |||
| 75 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 80 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
| 76 | pretrained_model_name_or_path, subfolder='scheduler') | 81 | pretrained_model_name_or_path, subfolder='scheduler') |
| 77 | 82 | ||
| 78 | embeddings = patch_managed_embeddings(text_encoder, emb_dropout) | 83 | embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout) |
| 79 | 84 | ||
| 80 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 85 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
| 81 | 86 | ||
| @@ -653,6 +658,8 @@ def train_loop( | |||
| 653 | on_checkpoint(global_step, "end") | 658 | on_checkpoint(global_step, "end") |
| 654 | raise KeyboardInterrupt | 659 | raise KeyboardInterrupt |
| 655 | 660 | ||
| 661 | return avg_loss, avg_acc, avg_loss_val, avg_acc_val | ||
| 662 | |||
| 656 | 663 | ||
| 657 | def train( | 664 | def train( |
| 658 | accelerator: Accelerator, | 665 | accelerator: Accelerator, |
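get_models() now threads the embedding LoRA hyperparameters through to patch_managed_embeddings, and train_loop() returns its running metrics (avg_loss, avg_acc, avg_loss_val, avg_acc_val) instead of falling off the end. A hedged sketch of the new get_models call; the checkpoint path is a placeholder:

    from training.functional import get_models

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "path/to/pretrained-model",   # placeholder
        emb_r=8,
        emb_lora_alpha=8,
        emb_lora_dropout=0.0,
    )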
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1517ee8..48236fb 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
| @@ -93,7 +93,7 @@ def lora_strategy_callbacks( | |||
| 93 | if use_emb_decay: | 93 | if use_emb_decay: |
| 94 | params = [ | 94 | params = [ |
| 95 | p | 95 | p |
| 96 | for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() | 96 | for p in text_encoder.text_model.embeddings.parameters() |
| 97 | if p.grad is not None | 97 | if p.grad is not None |
| 98 | ] | 98 | ] |
| 99 | return torch.stack(params) if len(params) != 0 else None | 99 | return torch.stack(params) if len(params) != 0 else None |
| @@ -180,7 +180,7 @@ def lora_prepare( | |||
| 180 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 180 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
| 181 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 181 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) |
| 182 | 182 | ||
| 183 | text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) | 183 | # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) |
| 184 | 184 | ||
| 185 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 185 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 186 | 186 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index ca7cc3d..49236c6 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
| @@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks( | |||
| 72 | 72 | ||
| 73 | if use_ema: | 73 | if use_ema: |
| 74 | ema_embeddings = EMAModel( | 74 | ema_embeddings = EMAModel( |
| 75 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 75 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
| 76 | inv_gamma=ema_inv_gamma, | 76 | inv_gamma=ema_inv_gamma, |
| 77 | power=ema_power, | 77 | power=ema_power, |
| 78 | max_value=ema_max_decay, | 78 | max_value=ema_max_decay, |
| @@ -84,7 +84,7 @@ def textual_inversion_strategy_callbacks( | |||
| 84 | def ema_context(): | 84 | def ema_context(): |
| 85 | if ema_embeddings is not None: | 85 | if ema_embeddings is not None: |
| 86 | return ema_embeddings.apply_temporary( | 86 | return ema_embeddings.apply_temporary( |
| 87 | text_encoder.text_model.embeddings.token_override_embedding.parameters() | 87 | text_encoder.text_model.embeddings.token_embedding.parameters() |
| 88 | ) | 88 | ) |
| 89 | else: | 89 | else: |
| 90 | return nullcontext() | 90 | return nullcontext() |
| @@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks( | |||
| 108 | if use_emb_decay: | 108 | if use_emb_decay: |
| 109 | params = [ | 109 | params = [ |
| 110 | p | 110 | p |
| 111 | for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() | 111 | for p in text_encoder.text_model.embeddings.token_embedding.parameters() |
| 112 | if p.grad is not None | 112 | if p.grad is not None |
| 113 | ] | 113 | ] |
| 114 | return torch.stack(params) if len(params) != 0 else None | 114 | return torch.stack(params) if len(params) != 0 else None |
| @@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks( | |||
| 116 | @torch.no_grad() | 116 | @torch.no_grad() |
| 117 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
| 118 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
| 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) |
| 120 | 120 | ||
| 121 | if use_emb_decay and w is not None: | 121 | if use_emb_decay and w is not None: |
| 122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] | 122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] |
| @@ -203,7 +203,6 @@ def textual_inversion_prepare( | |||
| 203 | text_encoder.text_model.encoder.requires_grad_(False) | 203 | text_encoder.text_model.encoder.requires_grad_(False) |
| 204 | text_encoder.text_model.final_layer_norm.requires_grad_(False) | 204 | text_encoder.text_model.final_layer_norm.requires_grad_(False) |
| 205 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | 205 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
| 206 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | ||
| 207 | 206 | ||
| 208 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 207 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 209 | 208 | ||
