summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py76
-rw-r--r--models/lora.py131
-rw-r--r--models/sparse.py66
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py20
-rw-r--r--train_lora.py7
-rw-r--r--train_ti.py28
-rw-r--r--training/functional.py11
-rw-r--r--training/strategy/lora.py4
-rw-r--r--training/strategy/ti.py9
9 files changed, 212 insertions, 140 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9be8256..60c1b20 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -11,49 +11,27 @@ from transformers import CLIPTextModel
11from transformers.models.clip import CLIPTextConfig 11from transformers.models.clip import CLIPTextConfig
12from transformers.models.clip.modeling_clip import CLIPTextEmbeddings 12from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
13 13
14from models.sparse import PseudoSparseEmbedding 14from models.lora import LoraEmbedding
15
16
17def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
18 old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
19
20 if old_num_embeddings == new_num_embeddings:
21 return old_embedding
22
23 n = min(old_num_embeddings, new_num_embeddings)
24
25 new_embedding = nn.Embedding(
26 new_num_embeddings,
27 old_embedding_dim,
28 device=old_embedding.weight.device,
29 dtype=old_embedding.weight.dtype
30 )
31 if initializer_factor is not None:
32 new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
33 else:
34 nn.init.zeros_(new_embedding.weight.data)
35 new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
36 return new_embedding
37 15
38 16
39class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): 17class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
40 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0): 18 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0.0):
41 super().__init__(config) 19 super().__init__(config)
42 20
43 self.token_embedding = embeddings.token_embedding
44 self.position_embedding = embeddings.position_embedding 21 self.position_embedding = embeddings.position_embedding
45 self.initializer_factor = config.initializer_factor 22 self.initializer_factor = config.initializer_factor
46 23 self.token_embedding = LoraEmbedding(
47 self.token_override_embedding = PseudoSparseEmbedding( 24 self.token_embedding.num_embeddings,
48 self.token_embedding.embedding_dim, 25 self.token_embedding.embedding_dim,
49 dropout_p=dropout_p, 26 r,
50 device=self.token_embedding.weight.device, 27 lora_alpha,
51 dtype=self.token_embedding.weight.dtype, 28 lora_dropout,
52 ) 29 )
53 30
31 self.token_embedding.weight = embeddings.token_embedding.weight
32
54 def resize(self, size: int): 33 def resize(self, size: int):
55 self.token_override_embedding.resize(size) 34 self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor)
56 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
57 35
58 def add_embed( 36 def add_embed(
59 self, 37 self,
@@ -87,7 +65,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
87 token_ids = torch.tensor(token_ids, dtype=torch.long) 65 token_ids = torch.tensor(token_ids, dtype=torch.long)
88 66
89 self.token_embedding.weight.data[token_ids] = initializer 67 self.token_embedding.weight.data[token_ids] = initializer
90 self.token_override_embedding.set(token_ids, initializer)
91 68
92 def load_embed(self, input_ids: list[int], filename: Path): 69 def load_embed(self, input_ids: list[int], filename: Path):
93 with safe_open(filename, framework="pt", device="cpu") as file: 70 with safe_open(filename, framework="pt", device="cpu") as file:
@@ -97,26 +74,14 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
97 save_file({"embed": self.get_embed(input_ids)}, filename) 74 save_file({"embed": self.get_embed(input_ids)}, filename)
98 75
99 def persist(self): 76 def persist(self):
100 input_ids = torch.arange( 77 self.token_embedding.eval()
101 self.token_embedding.num_embeddings, 78 self.token_embedding.merged = False
102 device=self.token_override_embedding.mapping.device
103 )
104 embs, mask = self.token_override_embedding(input_ids)
105 if embs is not None:
106 input_ids = input_ids[mask]
107 self.token_embedding.weight.data[input_ids] = embs
108 self.token_override_embedding.unset(input_ids)
109 79
110 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 80 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
111 if isinstance(input_ids, list): 81 if isinstance(input_ids, list):
112 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 82 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
113 83
114 embs = self.token_embedding(input_ids) 84 return self.token_embedding(input_ids)
115 embs_override, mask = self.token_override_embedding(input_ids)
116 if embs_override is not None:
117 embs[mask] = embs_override
118
119 return embs
120 85
121 def forward( 86 def forward(
122 self, 87 self,
@@ -138,7 +103,18 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
138 return embeddings 103 return embeddings
139 104
140 105
141def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings: 106def patch_managed_embeddings(
142 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p) 107 text_encoder: CLIPTextModel,
108 r: int = 8,
109 lora_alpha: int = 8,
110 lora_dropout: float = 0.0
111) -> ManagedCLIPTextEmbeddings:
112 text_embeddings = ManagedCLIPTextEmbeddings(
113 text_encoder.config,
114 text_encoder.text_model.embeddings,
115 r,
116 lora_alpha,
117 lora_dropout
118 )
143 text_encoder.text_model.embeddings = text_embeddings 119 text_encoder.text_model.embeddings = text_embeddings
144 return text_embeddings 120 return text_embeddings
diff --git a/models/lora.py b/models/lora.py
new file mode 100644
index 0000000..c0f74a6
--- /dev/null
+++ b/models/lora.py
@@ -0,0 +1,131 @@
1from typing import Optional
2
3import torch
4import torch.nn as nn
5import torch.nn.functional as F
6
7
8class LoraLayer():
9 def __init__(
10 self,
11 r: int,
12 lora_alpha: int,
13 lora_dropout: float,
14 merge_weights: bool,
15 ):
16 self.r = r
17 self.lora_alpha = lora_alpha
18 self.lora_dropout_p = lora_dropout
19
20 if lora_dropout > 0.:
21 self.lora_dropout = nn.Dropout(p=lora_dropout)
22 else:
23 self.lora_dropout = nn.Identity()
24
25 self.merged = False
26 self.merge_weights = merge_weights
27
28
29class LoraEmbedding(nn.Embedding, LoraLayer):
30 def __init__(
31 self,
32 num_embeddings: int,
33 embedding_dim: int,
34 r: int = 0,
35 lora_alpha: int = 1,
36 lora_dropout: float = 0.0,
37 merge_weights: bool = True,
38 **kwargs
39 ):
40 nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
41 LoraLayer.__init__(
42 self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
43 )
44
45 self.register_buffer('trainable_ids', torch.zeros(num_embeddings, device=self.weight.device, dtype=torch.long))
46 self.trainable_ids -= 1
47
48 if r > 0:
49 self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
50 self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
51 self.scaling = self.lora_alpha / self.r
52 self.weight.requires_grad = False
53
54 self.reset_parameters()
55
56 def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None):
57 n = min(self.num_embeddings, new_num_embeddings)
58
59 new_emb = LoraEmbedding(
60 new_num_embeddings,
61 self.embedding_dim,
62 self.r,
63 self.lora_alpha,
64 self.lora_dropout_p,
65 device=self.weight.device,
66 dtype=self.weight.dtype
67 )
68 if initializer_factor is not None:
69 new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
70 else:
71 nn.init.zeros_(new_emb.weight.data)
72 new_emb.weight.data[:n, :] = self.weight.data[:n, :]
73 new_emb.lora_A = self.lora_A
74 new_emb.lora_B = self.lora_B
75 new_emb.trainable_ids[:n] = self.trainable_ids[:n]
76
77 return new_emb
78
79 def mark_trainable(self, input_ids):
80 trainable_ids = self.trainable_ids[input_ids]
81 new_ids = trainable_ids[trainable_ids == -1]
82
83 if new_ids.shape[0] == 0:
84 return
85
86 n = self.trainable_ids.shape[0]
87 self.trainable_ids[new_ids] = torch.arange(n, n + new_ids.shape[0])
88
89 lora_A = nn.Parameter(self.weight.new_zeros((self.trainable_ids.shape[0], 0)))
90 lora_A.data[:n] = self.lora_A.data
91 self.lora_A = lora_A
92
93 def reset_parameters(self):
94 nn.Embedding.reset_parameters(self)
95 if hasattr(self, 'lora_A'):
96 nn.init.zeros_(self.lora_A)
97 nn.init.normal_(self.lora_B)
98
99 def train(self, mode: bool = True):
100 nn.Embedding.train(self, mode)
101 if self.merge_weights and self.merged:
102 if self.r > 0:
103 mask = ~(self.trainable_ids == -1)
104 trainable_ids = self.trainable_ids[mask]
105 self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling
106 self.merged = False
107
108 def eval(self):
109 nn.Embedding.eval(self)
110 if self.merge_weights and not self.merged:
111 if self.r > 0:
112 mask = ~(self.trainable_ids == -1)
113 trainable_ids = self.trainable_ids[mask]
114 self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling
115 self.merged = True
116
117 def forward(self, input_ids: torch.Tensor):
118 result = nn.Embedding.forward(self, input_ids)
119
120 if self.r > 0 and not self.merged:
121 trainable_ids = self.trainable_ids[input_ids]
122 mask = ~(trainable_ids == -1)
123 trainable_ids = trainable_ids[mask]
124
125 after_A = F.embedding(
126 trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm,
127 self.norm_type, self.scale_grad_by_freq, self.sparse
128 )
129 result[mask] += (after_A @ self.lora_B.T) * self.scaling
130
131 return result
diff --git a/models/sparse.py b/models/sparse.py
deleted file mode 100644
index 07b3413..0000000
--- a/models/sparse.py
+++ /dev/null
@@ -1,66 +0,0 @@
1from typing import Optional
2
3import torch
4import torch.nn as nn
5
6
7class PseudoSparseEmbedding(nn.Module):
8 def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
9 super().__init__()
10
11 self.embedding_dim = embedding_dim
12 self.dtype = dtype
13 self.params = nn.ParameterList()
14
15 if dropout_p > 0.0:
16 self.dropout = nn.Dropout(p=dropout_p)
17 else:
18 self.dropout = nn.Identity()
19
20 self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
21
22 def forward(self, input_ids: torch.LongTensor):
23 input_ids = input_ids.to(self.mapping.device)
24 ids = self.mapping[input_ids]
25 mask = ~(ids == -1)
26
27 if torch.all(~mask):
28 embs = None
29 else:
30 embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
31
32 return embs, mask
33
34 def resize(self, new_num_embeddings: int):
35 old_num_embeddings = self.mapping.shape[0]
36 n = min(old_num_embeddings, new_num_embeddings)
37
38 new_mapping = torch.zeros(new_num_embeddings, device=self.mapping.device, dtype=torch.long) - 1
39 new_mapping[:n] = self.mapping[:n]
40
41 self.mapping = new_mapping
42
43 def set(self, input_ids: torch.LongTensor, tensor: Optional[torch.Tensor] = None):
44 if len(input_ids.shape) != 0:
45 if tensor is not None:
46 return [self.set(id, t) for id, t in zip(input_ids, tensor)]
47 else:
48 return [self.set(id) for id in input_ids]
49
50 if tensor is None:
51 tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
52
53 if tensor.shape[-1] != self.embedding_dim:
54 raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
55
56 id = self.mapping[input_ids]
57
58 if id == -1:
59 id = len(self.params)
60 self.mapping[input_ids] = id
61 self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
62
63 self.params[id] = tensor
64
65 def unset(self, input_ids: torch.LongTensor):
66 self.mapping[input_ids] = -1
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 13ea2ac..a0dff54 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -591,15 +591,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
591 if callback is not None and i % callback_steps == 0: 591 if callback is not None and i % callback_steps == 0:
592 callback(i, t, latents) 592 callback(i, t, latents)
593 593
594 # 9. Post-processing
595 image = self.decode_latents(latents)
596
597 # 10. Run safety checker
598 has_nsfw_concept = None 594 has_nsfw_concept = None
599 595
600 # 11. Convert to PIL 596 if output_type == "latent":
601 if output_type == "pil": 597 image = latents
598 elif output_type == "pil":
599 # 9. Post-processing
600 image = self.decode_latents(latents)
601
602 # 10. Convert to PIL
602 image = self.numpy_to_pil(image) 603 image = self.numpy_to_pil(image)
604 else:
605 # 9. Post-processing
606 image = self.decode_latents(latents)
607
608 # Offload last model to CPU
609 if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
610 self.final_offload_hook.offload()
603 611
604 if not return_dict: 612 if not return_dict:
605 return (image, has_nsfw_concept) 613 return (image, has_nsfw_concept)
diff --git a/train_lora.py b/train_lora.py
index b8c7396..91bda5c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -387,7 +387,7 @@ def parse_args():
387 parser.add_argument( 387 parser.add_argument(
388 "--optimizer", 388 "--optimizer",
389 type=str, 389 type=str,
390 default="dadan", 390 default="adan",
391 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], 391 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
392 help='Optimizer to use' 392 help='Optimizer to use'
393 ) 393 )
@@ -412,7 +412,7 @@ def parse_args():
412 parser.add_argument( 412 parser.add_argument(
413 "--adam_weight_decay", 413 "--adam_weight_decay",
414 type=float, 414 type=float,
415 default=1e-2, 415 default=2e-2,
416 help="Weight decay to use." 416 help="Weight decay to use."
417 ) 417 )
418 parser.add_argument( 418 parser.add_argument(
@@ -780,6 +780,7 @@ def main():
780 timm.optim.Adan, 780 timm.optim.Adan,
781 weight_decay=args.adam_weight_decay, 781 weight_decay=args.adam_weight_decay,
782 eps=args.adam_epsilon, 782 eps=args.adam_epsilon,
783 no_prox=True,
783 ) 784 )
784 elif args.optimizer == 'lion': 785 elif args.optimizer == 'lion':
785 try: 786 try:
@@ -961,7 +962,7 @@ def main():
961 962
962 if len(args.placeholder_tokens) != 0: 963 if len(args.placeholder_tokens) != 0:
963 params_to_optimize.append({ 964 params_to_optimize.append({
964 "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(), 965 "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
965 "lr": learning_rate_emb, 966 "lr": learning_rate_emb,
966 "weight_decay": 0, 967 "weight_decay": 0,
967 }) 968 })
diff --git a/train_ti.py b/train_ti.py
index d931db6..6c57f4b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -18,7 +18,6 @@ import transformers
18 18
19from util.files import load_config, load_embeddings_from_dir 19from util.files import load_config, load_embeddings_from_dir
20from data.csv import VlpnDataModule, keyword_filter 20from data.csv import VlpnDataModule, keyword_filter
21from models.convnext.discriminator import ConvNeXtDiscriminator
22from training.functional import train, add_placeholder_tokens, get_models 21from training.functional import train, add_placeholder_tokens, get_models
23from training.strategy.ti import textual_inversion_strategy 22from training.strategy.ti import textual_inversion_strategy
24from training.optimization import get_scheduler 23from training.optimization import get_scheduler
@@ -354,7 +353,7 @@ def parse_args():
354 parser.add_argument( 353 parser.add_argument(
355 "--optimizer", 354 "--optimizer",
356 type=str, 355 type=str,
357 default="dadan", 356 default="adan",
358 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], 357 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
359 help='Optimizer to use' 358 help='Optimizer to use'
360 ) 359 )
@@ -379,7 +378,7 @@ def parse_args():
379 parser.add_argument( 378 parser.add_argument(
380 "--adam_weight_decay", 379 "--adam_weight_decay",
381 type=float, 380 type=float,
382 default=0, 381 default=2e-2,
383 help="Weight decay to use." 382 help="Weight decay to use."
384 ) 383 )
385 parser.add_argument( 384 parser.add_argument(
@@ -483,7 +482,19 @@ def parse_args():
483 help="The weight of prior preservation loss." 482 help="The weight of prior preservation loss."
484 ) 483 )
485 parser.add_argument( 484 parser.add_argument(
486 "--emb_dropout", 485 "--lora_r",
486 type=int,
487 default=8,
488 help="Lora rank, only used if use_lora is True"
489 )
490 parser.add_argument(
491 "--lora_alpha",
492 type=int,
493 default=32,
494 help="Lora alpha, only used if use_lora is True"
495 )
496 parser.add_argument(
497 "--lora_dropout",
487 type=float, 498 type=float,
488 default=0, 499 default=0,
489 help="Embedding dropout probability.", 500 help="Embedding dropout probability.",
@@ -655,7 +666,11 @@ def main():
655 save_args(output_dir, args) 666 save_args(output_dir, args)
656 667
657 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 668 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
658 args.pretrained_model_name_or_path, args.emb_dropout) 669 args.pretrained_model_name_or_path,
670 args.lora_r,
671 args.lora_alpha,
672 args.lora_dropout
673 )
659 674
660 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 675 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
661 tokenizer.set_dropout(args.vector_dropout) 676 tokenizer.set_dropout(args.vector_dropout)
@@ -747,6 +762,7 @@ def main():
747 timm.optim.Adan, 762 timm.optim.Adan,
748 weight_decay=args.adam_weight_decay, 763 weight_decay=args.adam_weight_decay,
749 eps=args.adam_epsilon, 764 eps=args.adam_epsilon,
765 no_prox=True,
750 ) 766 )
751 elif args.optimizer == 'lion': 767 elif args.optimizer == 'lion':
752 try: 768 try:
@@ -914,7 +930,7 @@ def main():
914 print("") 930 print("")
915 931
916 optimizer = create_optimizer( 932 optimizer = create_optimizer(
917 text_encoder.text_model.embeddings.token_override_embedding.parameters(), 933 text_encoder.text_model.embeddings.token_embedding.parameters(),
918 lr=learning_rate, 934 lr=learning_rate,
919 ) 935 )
920 936
diff --git a/training/functional.py b/training/functional.py
index 54bbe78..1fdfdc8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -66,7 +66,12 @@ class TrainingStrategy():
66 prepare: TrainingStrategyPrepareCallable 66 prepare: TrainingStrategyPrepareCallable
67 67
68 68
69def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0): 69def get_models(
70 pretrained_model_name_or_path: str,
71 emb_r: int = 8,
72 emb_lora_alpha: int = 8,
73 emb_lora_dropout: float = 0.0
74):
70 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 75 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
71 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 76 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
72 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 77 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -75,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
75 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 80 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
76 pretrained_model_name_or_path, subfolder='scheduler') 81 pretrained_model_name_or_path, subfolder='scheduler')
77 82
78 embeddings = patch_managed_embeddings(text_encoder, emb_dropout) 83 embeddings = patch_managed_embeddings(text_encoder, emb_r, emb_lora_alpha, emb_lora_dropout)
79 84
80 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings 85 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
81 86
@@ -653,6 +658,8 @@ def train_loop(
653 on_checkpoint(global_step, "end") 658 on_checkpoint(global_step, "end")
654 raise KeyboardInterrupt 659 raise KeyboardInterrupt
655 660
661 return avg_loss, avg_acc, avg_loss_val, avg_acc_val
662
656 663
657def train( 664def train(
658 accelerator: Accelerator, 665 accelerator: Accelerator,
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1517ee8..48236fb 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -93,7 +93,7 @@ def lora_strategy_callbacks(
93 if use_emb_decay: 93 if use_emb_decay:
94 params = [ 94 params = [
95 p 95 p
96 for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() 96 for p in text_encoder.text_model.embeddings.parameters()
97 if p.grad is not None 97 if p.grad is not None
98 ] 98 ]
99 return torch.stack(params) if len(params) != 0 else None 99 return torch.stack(params) if len(params) != 0 else None
@@ -180,7 +180,7 @@ def lora_prepare(
180 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 180 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
181 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) 181 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
182 182
183 text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) 183 # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
184 184
185 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 185 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
186 186
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index ca7cc3d..49236c6 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks(
72 72
73 if use_ema: 73 if use_ema:
74 ema_embeddings = EMAModel( 74 ema_embeddings = EMAModel(
75 text_encoder.text_model.embeddings.token_override_embedding.parameters(), 75 text_encoder.text_model.embeddings.token_embedding.parameters(),
76 inv_gamma=ema_inv_gamma, 76 inv_gamma=ema_inv_gamma,
77 power=ema_power, 77 power=ema_power,
78 max_value=ema_max_decay, 78 max_value=ema_max_decay,
@@ -84,7 +84,7 @@ def textual_inversion_strategy_callbacks(
84 def ema_context(): 84 def ema_context():
85 if ema_embeddings is not None: 85 if ema_embeddings is not None:
86 return ema_embeddings.apply_temporary( 86 return ema_embeddings.apply_temporary(
87 text_encoder.text_model.embeddings.token_override_embedding.parameters() 87 text_encoder.text_model.embeddings.token_embedding.parameters()
88 ) 88 )
89 else: 89 else:
90 return nullcontext() 90 return nullcontext()
@@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks(
108 if use_emb_decay: 108 if use_emb_decay:
109 params = [ 109 params = [
110 p 110 p
111 for p in text_encoder.text_model.embeddings.token_override_embedding.parameters() 111 for p in text_encoder.text_model.embeddings.token_embedding.parameters()
112 if p.grad is not None 112 if p.grad is not None
113 ] 113 ]
114 return torch.stack(params) if len(params) != 0 else None 114 return torch.stack(params) if len(params) != 0 else None
@@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks(
116 @torch.no_grad() 116 @torch.no_grad()
117 def on_after_optimize(w, lrs: dict[str, float]): 117 def on_after_optimize(w, lrs: dict[str, float]):
118 if ema_embeddings is not None: 118 if ema_embeddings is not None:
119 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters()) 119 ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
120 120
121 if use_emb_decay and w is not None: 121 if use_emb_decay and w is not None:
122 lr = lrs["emb"] if "emb" in lrs else lrs["0"] 122 lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -203,7 +203,6 @@ def textual_inversion_prepare(
203 text_encoder.text_model.encoder.requires_grad_(False) 203 text_encoder.text_model.encoder.requires_grad_(False)
204 text_encoder.text_model.final_layer_norm.requires_grad_(False) 204 text_encoder.text_model.final_layer_norm.requires_grad_(False)
205 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 205 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
206 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
207 206
208 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 207 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
209 208