diff options
author | Volpeon <git@volpeon.ink> | 2023-04-15 13:11:11 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-15 13:11:11 +0200 |
commit | 99b4dba56e3e1e434820d1221d561e90f1a6d30a (patch) | |
tree | 717a4099e9ebfedec702060fed5ed12aaceb0094 | |
parent | Added cycle LR decay (diff) | |
download | textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.gz textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.tar.bz2 textual-inversion-diff-99b4dba56e3e1e434820d1221d561e90f1a6d30a.zip |
TI via LoRA
-rw-r--r-- | models/clip/embeddings.py | 76 | ||||
-rw-r--r-- | models/lora.py | 131 | ||||
-rw-r--r-- | models/sparse.py | 66 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 20 | ||||
-rw-r--r-- | train_lora.py | 7 | ||||
-rw-r--r-- | train_ti.py | 28 | ||||
-rw-r--r-- | training/functional.py | 11 | ||||
-rw-r--r-- | training/strategy/lora.py | 4 | ||||
-rw-r--r-- | training/strategy/ti.py | 9 |
9 files changed, 212 insertions, 140 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9be8256..60c1b20 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -11,49 +11,27 @@ from transformers import CLIPTextModel | |||
11 | from transformers.models.clip import CLIPTextConfig | 11 | from transformers.models.clip import CLIPTextConfig |
12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | 12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings |
13 | 13 | ||
14 | from models.sparse import PseudoSparseEmbedding | 14 | from models.lora import LoraEmbedding |
15 | |||
16 | |||
17 | def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding: | ||
18 | old_num_embeddings, old_embedding_dim = old_embedding.weight.shape | ||
19 | |||
20 | if old_num_embeddings == new_num_embeddings: | ||
21 | return old_embedding | ||
22 | |||
23 | n = min(old_num_embeddings, new_num_embeddings) | ||
24 | |||
25 | new_embedding = nn.Embedding( | ||
26 | new_num_embeddings, | ||
27 | old_embedding_dim, | ||
28 | device=old_embedding.weight.device, | ||
29 | dtype=old_embedding.weight.dtype | ||
30 | ) | ||
31 | if initializer_factor is not None: | ||
32 | new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | ||
33 | else: | ||
34 | nn.init.zeros_(new_embedding.weight.data) | ||
35 | new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] | ||
36 | return new_embedding | ||
37 | 15 | ||
38 | 16 | ||
39 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 17 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
40 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0): | 18 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0): |
41 | super().__init__(config) | 19 | super().__init__(config) |
42 | 20 | ||
43 | self.token_embedding = embeddings.token_embedding | ||
44 | self.position_embedding = embeddings.position_embedding | 21 | self.position_embedding = embeddings.position_embedding |
45 | self.initializer_factor = config.initializer_factor | 22 | self.initializer_factor = config.initializer_factor |
46 | 23 | self.token_embedding = LoraEmbedding( | |
47 | self.token_override_embedding = PseudoSparseEmbedding( | 24 | self.token_embedding.num_embeddings, |
48 | self.token_embedding.embedding_dim, | 25 | self.token_embedding.embedding_dim, |
49 | dropout_p=dropout_p, | 26 | r, |
50 | device=self.token_embedding.weight.device, | 27 | lora_alpha, |
51 | dtype=self.token_embedding.weight.dtype, | 28 | lora_dropout, |
52 | ) | 29 | ) |
53 | 30 | ||
31 | self.token_embedding.weight = embeddings.token_embedding.weight | ||
32 | |||
54 | def resize(self, size: int): | 33 | def resize(self, size: int): |
55 | self.token_override_embedding.resize(size) | 34 | self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor) |
56 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | ||
57 | 35 | ||
58 | def add_embed( | 36 | def add_embed( |
59 | self, | 37 | self, |
@@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
87 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 65 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
88 | 66 | ||
89 | self.token_embedding.weight.data[token_ids] = initializer | 67 | self.token_embedding.weight.data[token_ids] = initializer |
90 | self.token_override_embedding.set(token_ids, initializer) | ||
91 | 68 | ||
92 | def load_embed(self, input_ids: list[int], filename: Path): | 69 | def load_embed(self, input_ids: list[int], filename: Path): |
93 | with safe_open(filename, framework="pt", device="cpu") as file: | 70 | with safe_open(filename, framework="pt", device="cpu") as file: |
@@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
97 | save_file({"embed": self.get_embed(input_ids)}, filename) | 74 | save_file({"embed": self.get_embed(input_ids)}, filename) |
98 | 75 | ||
99 | def persist(self): | 76 | def persist(self): |
100 | input_ids = torch.arange( | 77 | self.token_embedding.eval() |
101 | self.token_embedding.num_embeddings, | 78 | self.token_embedding.merged = False |
102 | device=self.token_override_embedding.mapping.device | ||
103 | ) | ||
104 | embs, mask = self.token_override_embedding(input_ids) | ||
105 | if embs is not None: | ||
106 | input_ids = input_ids[mask] | ||
107 | self.token_embedding.weight.data[input_ids] = embs | ||
108 | self.token_override_embedding.unset(input_ids) | ||
109 | 79 | ||
110 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 80 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
111 | if isinstance(input_ids, list): | 81 | if isinstance(input_ids, list): |
112 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 82 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
113 | 83 | ||
114 | embs = self.token_embedding(input_ids) | 84 | return self.token_embedding(input_ids) |
115 | embs_override, mask = self.token_override_embedding(input_ids) | ||
116 | if embs_override is not None: | ||
117 | embs[mask] = embs_override | ||
118 | |||
119 | return embs | ||
120 | 85 | ||
121 | def forward( | 86 | def forward( |
122 | self, | 87 | self, |
@@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
138 | return embeddings | 103 | return embeddings |
139 | 104 | ||
140 | 105 | ||
141 | def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings: | 106 | def patch_managed_embeddings( |
142 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p) | 107 | text_encoder: CLIPTextModel, |
108 | r: int = 8, | ||
109 | lora_alpha: int = 8, | ||
110 | lora_dropout: float = 0.0 | ||
111 | ) -> ManagedCLIPTextEmbeddings: | ||
112 | text_embeddings = ManagedCLIPTextEmbeddings( | ||
113 | text_encoder.config, | ||
114 | text_encoder.text_model.embeddings, | ||
115 | r, | ||
116 | lora_alpha, | ||
117 | lora_dropout | ||
118 | ) | ||
143 | text_encoder.text_model.embeddings = text_embeddings | 119 | text_encoder.text_model.embeddings = text_embeddings |
144 | return text_embeddings | 120 | return text_embeddings |
diff --git a/models/lora.py b/models/lora.py new file mode 100644 index 0000000..c0f74a6 --- /dev/null +++ b/models/lora.py | |||
@@ -0,0 +1,131 @@ | |||
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | import torch.nn as nn | ||
5 | import torch.nn.functional as F | ||
6 | |||
7 | |||
8 | class LoraLayer(): | ||
9 | def __init__( | ||
10 | self, | ||
11 | r: int, | ||
12 | lora_alpha: int, | ||
13 | lora_dropout: float, | ||
14 | merge_weights: bool, | ||
15 | ): | ||
16 | self.r = r | ||
17 | self.lora_alpha = lora_alpha | ||
18 | self.lora_dropout_p = lora_dropout | ||
19 | |||
20 | if lora_dropout > 0.: | ||
21 | self.lora_dropout = nn.Dropout(p=lora_dropout) | ||
22 | else: | ||
23 | self.lora_dropout = nn.Identity() | ||
24 | |||
25 | self.merged = False | ||
26 | self.merge_weights = merge_weights | ||
27 | |||
28 | |||
29 | class LoraEmbedding(nn.Embedding, LoraLayer): | ||
30 | def __init__( | ||
31 | self, | ||
32 | num_embeddings: int, | ||
33 | embedding_dim: int, | ||
34 | r: int = 0, | ||
35 | lora_alpha: int = 1, | ||
36 | lora_dropout: float = 0.0, | ||
37 | merge_weights: bool = True, | ||
38 | **kwargs | ||
39 | ): | ||
40 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) | ||
41 | LoraLayer.__init__( | ||
42 | self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights | ||
43 | ) | ||
44 | |||
45 | self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long)) | ||
46 | self.trainable_ids -= 1 | ||
47 | |||
48 | if r > 0: | ||
49 | self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0))) | ||
50 | self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r))) | ||
51 | self.scaling = self.lora_alpha / self.r | ||
52 | self.weight.requires_grad = False | ||
53 | |||
54 | self.reset_parameters() | ||
55 | |||
56 | def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): | ||
57 | n = min(self.num_embeddings, new_num_embeddings) | ||
58 | |||
59 | new_emb = LoraEmbedding( | ||
60 | new_num_embeddings, | ||
61 | self.embedding_dim, | ||
62 | self.r, | ||
63 | self.lora_alpha, | ||
64 | self.lora_dropout_p, | ||
65 | device=self.weight.device, | ||
66 | dtype=self.weight.dtype | ||
67 | ) | ||
68 | if initializer_factor is not None: | ||
69 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | ||
70 | else: | ||
71 | nn.init.zeros_(new_emb.weight.data) | ||
72 | new_emb.weight.data[:n, :] = self.weight.data[:n, :] | ||
73 | new_emb.lora_A = self.lora_A | ||
74 | new_emb.lora_B = self.lora_B | ||
75 | new_emb.trainable_ids[:n] = self.trainable_ids[:n] | ||
76 | |||
77 | return new_emb | ||
78 | |||
79 | def mark_trainable(self, input_ids): | ||
80 | trainable_ids = self.trainable_ids[input_ids] | ||
81 | new_ids = trainable_ids[trainable_ids == -1] | ||
82 | |||
83 | if new_ids.shape[0] == 0: | ||
84 | return | ||
85 | |||
86 | n = self.trainable_ids.shape[0] | ||
87 | self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0]) | ||
88 | |||
89 | lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0))) | ||
90 | lora_A.data[:n] = self.lora_A.data | ||
91 | self.lora_A = lora_A | ||
92 | |||
93 | def reset_parameters(self): | ||
94 | nn.Embedding.reset_parameters(self) | ||
95 | if hasattr(self, 'lora_A'): | ||
96 | nn.init.zeros_(self.lora_A) | ||
97 | nn.init.normal_(self.lora_B) | ||
98 | |||
99 | def train(self, mode: bool = True): | ||
100 | nn.Embedding.train(self, mode) | ||
101 | if self.merge_weights and self.merged: | ||
102 | if self.r > 0: | ||
103 | mask = ~(self.trainable_ids == -1) | ||
104 | trainable_ids = self.trainable_ids[mask] | ||
105 | self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling | ||
106 | self.merged = False | ||
107 | |||
108 | def eval(self): | ||
109 | nn.Embedding.eval(self) | ||
110 | if self.merge_weights and not self.merged: | ||
111 | if self.r > 0: | ||
112 | mask = ~(self.trainable_ids == -1) | ||
113 | trainable_ids = self.trainable_ids[mask] | ||
114 | self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling | ||
115 | self.merged = True | ||
116 | |||
117 | def forward(self, input_ids: torch.Tensor): | ||
118 | result = nn.Embedding.forward(self, input_ids) | ||
119 | |||
120 | if self.r > 0 and not self.merged: | ||
121 | trainable_ids = self.trainable_ids[input_ids] | ||
122 | mask = ~(trainable_ids == -1) | ||
123 | trainable_ids = trainable_ids[mask] | ||
124 | |||
125 | after_A = F.embedding( | ||
126 | trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm, | ||
127 | self.norm_type, self.scale_grad_by_freq, self.sparse | ||
128 | ) | ||
129 | result[mask] += (after_A @ self.lora_B.T) * self.scaling | ||
130 | |||
131 | return result | ||
diff --git a/models/sparse.py b/models/sparse.py deleted file mode 100644 index 07b3413..0000000 --- a/models/sparse.py +++ /dev/null | |||
@@ -1,66 +0,0 @@ | |||
1 | from typing import Optional | ||
2 | |||
3 | import torch | ||
4 | import torch.nn as nn | ||
5 | |||
6 | |||
7 | class PseudoSparseEmbedding(nn.Module): | ||
8 | def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32): | ||
9 | super().__init__() | ||
10 | |||
11 | self.embedding_dim = embedding_dim | ||
12 | self.dtype = dtype | ||
13 | self.params = nn.ParameterList() | ||
14 | |||
15 | if dropout_p > 0.0: | ||
16 | self.dropout = nn.Dropout(p=dropout_p) | ||
17 | else: | ||
18 | self.dropout = nn.Identity() | ||
19 | |||
20 | self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long)) | ||
21 | |||
22 | def forward(self, input_ids: torch.LongTensor): | ||
23 | input_ids = input_ids.to(self.mapping.device) | ||
24 | ids = self.mapping[input_ids] | ||
25 | mask = ~(ids == -1) | ||
26 | |||
27 | if torch.all(~mask): | ||
28 | embs = None | ||
29 | else: | ||
30 | embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]])) | ||
31 | |||
32 | return embs, mask | ||
33 | |||
34 | def resize(self, new_num_embeddings: int): | ||
35 | old_num_embeddings = self.mapping.shape[0] | ||
36 | n = min(old_num_embeddings, new_num_embeddings) | ||
37 | |||
38 | new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1 | ||
39 | new_mapping[:n] = self.mapping[:n] | ||
40 | |||
41 | self.mapping = new_mapping | ||
42 | |||
43 | def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None): | ||
44 | if len(input_ids.shape) != 0: | ||
45 | if tensor is not None: | ||
46 | return [self.set(id, t) for id, t in zip(input_ids, tensor)] | ||
47 | else: | ||
48 | return [self.set(id) for id in input_ids] | ||
49 | |||
50 | if tensor is None: | ||
51 | tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype) | ||
52 | |||
53 | if tensor.shape[-1] != self.embedding_dim: | ||
54 | raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]") | ||
55 | |||
56 | id = self.mapping[input_ids] | ||
57 | |||
58 | if id == -1: | ||
59 | id = len(self.params) | ||
60 | self.mapping[input_ids] = id | ||
61 | self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) | ||
62 | |||
63 | self.params[id] = tensor | ||
64 | |||
65 | def unset(self, input_ids: torch.LongTensor): | ||
66 | self.mapping[input_ids] = -1 | ||
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 13ea2ac..a0dff54 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -591,15 +591,23 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
591 | if callback is not None and i % callback_steps == 0: | 591 | if callback is not None and i % callback_steps == 0: |
592 | callback(i, t, latents) | 592 | callback(i, t, latents) |
593 | 593 | ||
594 | # 9. Post-processing | ||
595 | image = self.decode_latents(latents) | ||
596 | |||
597 | # 10. Run safety checker | ||
598 | has_nsfw_concept = None | 594 | has_nsfw_concept = None |
599 | 595 | ||
600 | # 11. Convert to PIL | 596 | if output_type == "latent": |
601 | if output_type == "pil": | 597 | image = latents |
598 | elif output_type == "pil": | ||
599 | # 9. Post-processing | ||
600 | image = self.decode_latents(latents) | ||
601 | |||
602 | # 10. Convert to PIL | ||
602 | image = self.numpy_to_pil(image) | 603 | image = self.numpy_to_pil(image) |
604 | else: | ||
605 | # 9. Post-processing | ||
606 | image = self.decode_latents(latents) | ||
607 | |||
608 | # Offload last model to CPU | ||
609 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: | ||
610 | self.final_offload_hook.offload() | ||
603 | 611 | ||
604 | if not return_dict: | 612 | if not return_dict: |
605 | return (image, has_nsfw_concept) | 613 | return (image, has_nsfw_concept) |
diff --git a/train_lora.py b/train_lora.py index b8c7396..91bda5c 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -387,7 +387,7 @@ def parse_args(): | |||
387 | parser.add_argument( | 387 | parser.add_argument( |
388 | "--optimizer", | 388 | "--optimizer", |
389 | type=str, | 389 | type=str, |
390 | default="dadan", | 390 | default="adan", |
391 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 391 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
392 | help='Optimizer to use' | 392 | help='Optimizer to use' |
393 | ) | 393 | ) |
@@ -412,7 +412,7 @@ def parse_args(): | |||
412 | parser.add_argument( | 412 | parser.add_argument( |
413 | "--adam_weight_decay", | 413 | "--adam_weight_decay", |
414 | type=float, | 414 | type=float, |
415 | default=1e-2, | 415 | default=2e-2, |
416 | help="Weight decay to use." | 416 | help="Weight decay to use." |
417 | ) | 417 | ) |
418 | parser.add_argument( | 418 | parser.add_argument( |
@@ -780,6 +780,7 @@ def main(): | |||
780 | timm.optim.Adan, | 780 | timm.optim.Adan, |
781 | weight_decay=args.adam_weight_decay, | 781 | weight_decay=args.adam_weight_decay, |
782 | eps=args.adam_epsilon, | 782 | eps=args.adam_epsilon, |
783 | no_prox=True, | ||
783 | ) | 784 | ) |
784 | elif args.optimizer == 'lion': | 785 | elif args.optimizer == 'lion': |
785 | try: | 786 | try: |
@@ -961,7 +962,7 @@ def main(): | |||
961 | 962 | ||
962 | if len(args.placeholder_tokens) != 0: | 963 | if len(args.placeholder_tokens) != 0: |
963 | params_to_optimize.append({ | 964 | params_to_optimize.append({ |
964 | "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 965 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), |
965 | "lr": learning_rate_emb, | 966 | "lr": learning_rate_emb, |
966 | "weight_decay": 0, | 967 | "weight_decay": 0, |
967 | }) | 968 | }) |
diff --git a/train_ti.py b/train_ti.py index d931db6..6c57f4b 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -18,7 +18,6 @@ import transformers | |||
18 | 18 | ||
19 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
20 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
21 | from models.convnext.discriminator import ConvNeXtDiscriminator | ||
22 | from training.functional import train, add_placeholder_tokens, get_models | 21 | from training.functional import train, add_placeholder_tokens, get_models |
23 | from training.strategy.ti import textual_inversion_strategy | 22 | from training.strategy.ti import textual_inversion_strategy |
24 | from training.optimization import get_scheduler | 23 | from training.optimization import get_scheduler |
@@ -354,7 +353,7 @@ def parse_args(): | |||
354 | parser.add_argument( | 353 | parser.add_argument( |
355 | "--optimizer", | 354 | "--optimizer", |
356 | type=str, | 355 | type=str, |
357 | default="dadan", | 356 | default="adan", |
358 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 357 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
359 | help='Optimizer to use' | 358 | help='Optimizer to use' |
360 | ) | 359 | ) |
@@ -379,7 +378,7 @@ def parse_args(): | |||
379 | parser.add_argument( | 378 | parser.add_argument( |
380 | "--adam_weight_decay", | 379 | "--adam_weight_decay", |
381 | type=float, | 380 | type=float, |
382 | default=0, | 381 | default=2e-2, |
383 | help="Weight decay to use." | 382 | help="Weight decay to use." |
384 | ) | 383 | ) |
385 | parser.add_argument( | 384 | parser.add_argument( |
@@ -483,7 +482,19 @@ def parse_args(): | |||
483 | help="The weight of prior preservation loss." | 482 | help="The weight of prior preservation loss." |
484 | ) | 483 | ) |
485 | parser.add_argument( | 484 | parser.add_argument( |
486 | "--emb_dropout", | 485 | "--lora_r", |
486 | type=int, | ||
487 | default=8, | ||
488 | help="Lora rank, only used if use_lora is True" | ||
489 | ) | ||
490 | parser.add_argument( | ||
491 | "--lora_alpha", | ||
492 | type=int, | ||
493 | default=32, | ||
494 | help="Lora alpha, only used if use_lora is True" | ||
495 | ) | ||
496 | parser.add_argument( | ||
497 | "--lora_dropout", | ||
487 | type=float, | 498 | type=float, |
488 | default=0, | 499 | default=0, |
489 | help="Embedding dropout probability.", | 500 | help="Embedding dropout probability.", |
@@ -655,7 +666,11 @@ def main(): | |||
655 | save_args(output_dir, args) | 666 | save_args(output_dir, args) |
656 | 667 | ||
657 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 668 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
658 | args.pretrained_model_name_or_path, args.emb_dropout) | 669 | args.pretrained_model_name_or_path, |
670 | args.lora_r, | ||
671 | args.lora_alpha, | ||
672 | args.lora_dropout | ||
673 | ) | ||
659 | 674 | ||
660 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 675 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
661 | tokenizer.set_dropout(args.vector_dropout) | 676 | tokenizer.set_dropout(args.vector_dropout) |
@@ -747,6 +762,7 @@ def main(): | |||
747 | timm.optim.Adan, | 762 | timm.optim.Adan, |
748 | weight_decay=args.adam_weight_decay, | 763 | weight_decay=args.adam_weight_decay, |
749 | eps=args.adam_epsilon, | 764 | eps=args.adam_epsilon, |
765 | no_prox=True, | ||
750 | ) | 766 | ) |
751 | elif args.optimizer == 'lion': | 767 | elif args.optimizer == 'lion': |
752 | try: | 768 | try: |
@@ -914,7 +930,7 @@ def main(): | |||
914 | print("") | 930 | print("") |
915 | 931 | ||
916 | optimizer = create_optimizer( | 932 | optimizer = create_optimizer( |
917 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 933 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
918 | lr=learning_rate, | 934 | lr=learning_rate, |
919 | ) | 935 | ) |
920 | 936 | ||
diff --git a/training/functional.py b/training/functional.py index 54bbe78..1fdfdc8 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -66,7 +66,12 @@ class TrainingStrategy(): | |||
66 | prepare: TrainingStrategyPrepareCallable | 66 | prepare: TrainingStrategyPrepareCallable |
67 | 67 | ||
68 | 68 | ||
69 | def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): | 69 | def get_models( |
70 | pretrained_model_name_or_path: str, | ||
71 | emb_r: int = 8, | ||
72 | emb_lora_alpha: int = 8, | ||
73 | emb_lora_dropout: float = 0.0 | ||
74 | ): | ||
70 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 75 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
71 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 76 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
72 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 77 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
@@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): | |||
75 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 80 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
76 | pretrained_model_name_or_path, subfolder='scheduler') | 81 | pretrained_model_name_or_path, subfolder='scheduler') |
77 | 82 | ||
78 | embeddings = patch_managed_embeddings(text_encoder, emb_dropout) | 83 | embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout) |
79 | 84 | ||
80 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 85 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
81 | 86 | ||
@@ -653,6 +658,8 @@ def train_loop( | |||
653 | on_checkpoint(global_step, "end") | 658 | on_checkpoint(global_step, "end") |
654 | raise KeyboardInterrupt | 659 | raise KeyboardInterrupt |
655 | 660 | ||
661 | return avg_loss, avg_acc, avg_loss_val, avg_acc_val | ||
662 | |||
656 | 663 | ||
657 | def train( | 664 | def train( |
658 | accelerator: Accelerator, | 665 | accelerator: Accelerator, |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1517ee8..48236fb 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -93,7 +93,7 @@ def lora_strategy_callbacks( | |||
93 | if use_emb_decay: | 93 | if use_emb_decay: |
94 | params = [ | 94 | params = [ |
95 | p | 95 | p |
96 | for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() | 96 | for p in text_encoder.text_model.embeddings.parameters() |
97 | if p.grad is not None | 97 | if p.grad is not None |
98 | ] | 98 | ] |
99 | return torch.stack(params) if len(params) != 0 else None | 99 | return torch.stack(params) if len(params) != 0 else None |
@@ -180,7 +180,7 @@ def lora_prepare( | |||
180 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 180 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
181 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 181 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) |
182 | 182 | ||
183 | text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) | 183 | # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) |
184 | 184 | ||
185 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 185 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
186 | 186 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index ca7cc3d..49236c6 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks( | |||
72 | 72 | ||
73 | if use_ema: | 73 | if use_ema: |
74 | ema_embeddings = EMAModel( | 74 | ema_embeddings = EMAModel( |
75 | text_encoder.text_model.embeddings.token_override_embedding.parameters(), | 75 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
76 | inv_gamma=ema_inv_gamma, | 76 | inv_gamma=ema_inv_gamma, |
77 | power=ema_power, | 77 | power=ema_power, |
78 | max_value=ema_max_decay, | 78 | max_value=ema_max_decay, |
@@ -84,7 +84,7 @@ def textual_inversion_strategy_callbacks( | |||
84 | def ema_context(): | 84 | def ema_context(): |
85 | if ema_embeddings is not None: | 85 | if ema_embeddings is not None: |
86 | return ema_embeddings.apply_temporary( | 86 | return ema_embeddings.apply_temporary( |
87 | text_encoder.text_model.embeddings.token_override_embedding.parameters() | 87 | text_encoder.text_model.embeddings.token_embedding.parameters() |
88 | ) | 88 | ) |
89 | else: | 89 | else: |
90 | return nullcontext() | 90 | return nullcontext() |
@@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks( | |||
108 | if use_emb_decay: | 108 | if use_emb_decay: |
109 | params = [ | 109 | params = [ |
110 | p | 110 | p |
111 | for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() | 111 | for p in text_encoder.text_model.embeddings.token_embedding.parameters() |
112 | if p.grad is not None | 112 | if p.grad is not None |
113 | ] | 113 | ] |
114 | return torch.stack(params) if len(params) != 0 else None | 114 | return torch.stack(params) if len(params) != 0 else None |
@@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks( | |||
116 | @torch.no_grad() | 116 | @torch.no_grad() |
117 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
118 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) |
120 | 120 | ||
121 | if use_emb_decay and w is not None: | 121 | if use_emb_decay and w is not None: |
122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] | 122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] |
@@ -203,7 +203,6 @@ def textual_inversion_prepare( | |||
203 | text_encoder.text_model.encoder.requires_grad_(False) | 203 | text_encoder.text_model.encoder.requires_grad_(False) |
204 | text_encoder.text_model.final_layer_norm.requires_grad_(False) | 204 | text_encoder.text_model.final_layer_norm.requires_grad_(False) |
205 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | 205 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
206 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | ||
207 | 206 | ||
208 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 207 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
209 | 208 | ||