summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py15
-rw-r--r--models/sparse.py14
-rw-r--r--train_ti.py24
-rw-r--r--training/functional.py4
-rw-r--r--training/strategy/ti.py22
5 files changed, 57 insertions, 22 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index a356434..63a141f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
37 37
38 38
39class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): 39class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
40 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0): 40 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
41 super().__init__(config) 41 super().__init__(config)
42 42
43 self.token_embedding = embeddings.token_embedding 43 self.token_embedding = embeddings.token_embedding
@@ -49,7 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
49 device=self.token_embedding.weight.device, 49 device=self.token_embedding.weight.device,
50 dtype=self.token_embedding.weight.dtype, 50 dtype=self.token_embedding.weight.dtype,
51 ) 51 )
52 self.alpha = alpha
53 52
54 def resize(self, size: int): 53 def resize(self, size: int):
55 self.token_override_embedding.resize(size) 54 self.token_override_embedding.resize(size)
@@ -87,7 +86,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
87 token_ids = torch.tensor(token_ids, dtype=torch.long) 86 token_ids = torch.tensor(token_ids, dtype=torch.long)
88 87
89 self.token_embedding.weight.data[token_ids] = initializer 88 self.token_embedding.weight.data[token_ids] = initializer
90 self.token_override_embedding.set(token_ids) 89 self.token_override_embedding.set(token_ids, initializer)
91 90
92 def load_embed(self, input_ids: list[int], filename: Path): 91 def load_embed(self, input_ids: list[int], filename: Path):
93 with safe_open(filename, framework="pt", device="cpu") as file: 92 with safe_open(filename, framework="pt", device="cpu") as file:
@@ -101,8 +100,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
101 embs, mask = self.token_override_embedding(input_ids) 100 embs, mask = self.token_override_embedding(input_ids)
102 if embs is not None: 101 if embs is not None:
103 input_ids = input_ids[mask] 102 input_ids = input_ids[mask]
104 self.token_embedding.weight.data[input_ids] += self.alpha * embs 103 self.token_embedding.weight.data[input_ids] = embs
105 self.token_override_embedding.unset(input_ids) 104 self.token_override_embedding.unset(input_ids)
106 105
107 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 106 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
108 if isinstance(input_ids, list): 107 if isinstance(input_ids, list):
@@ -111,7 +110,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
111 embs = self.token_embedding(input_ids) 110 embs = self.token_embedding(input_ids)
112 embs_override, mask = self.token_override_embedding(input_ids) 111 embs_override, mask = self.token_override_embedding(input_ids)
113 if embs_override is not None: 112 if embs_override is not None:
114 embs[mask] += self.alpha * embs_override 113 embs[mask] = embs_override
115 114
116 return embs 115 return embs
117 116
@@ -135,7 +134,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
135 return embeddings 134 return embeddings
136 135
137 136
138def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings: 137def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
139 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha) 138 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
140 text_encoder.text_model.embeddings = text_embeddings 139 text_encoder.text_model.embeddings = text_embeddings
141 return text_embeddings 140 return text_embeddings
diff --git a/models/sparse.py b/models/sparse.py
index 0b15454..8910316 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -13,10 +13,7 @@ class PseudoSparseEmbedding(nn.Module):
13 self.params = nn.ParameterList() 13 self.params = nn.ParameterList()
14 self.mapping = torch.zeros(0, device=device, dtype=torch.long) 14 self.mapping = torch.zeros(0, device=device, dtype=torch.long)
15 15
16 def forward(self, input_ids: Optional[torch.LongTensor] = None): 16 def forward(self, input_ids: torch.LongTensor):
17 if input_ids is None:
18 input_ids = torch.arange(self.mapping.shape[0])
19
20 ids = self.mapping[input_ids.to(self.mapping.device)] 17 ids = self.mapping[input_ids.to(self.mapping.device)]
21 mask = ~(ids == -1) 18 mask = ~(ids == -1)
22 19
@@ -43,6 +40,12 @@ class PseudoSparseEmbedding(nn.Module):
43 else: 40 else:
44 return [self.set(id) for id in input_ids] 41 return [self.set(id) for id in input_ids]
45 42
43 if tensor is None:
44 tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
45
46 if tensor.shape[-1] != self.embedding_dim:
47 raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
48
46 id = self.mapping[input_ids] 49 id = self.mapping[input_ids]
47 50
48 if id == -1: 51 if id == -1:
@@ -50,8 +53,7 @@ class PseudoSparseEmbedding(nn.Module):
50 self.mapping[input_ids] = id 53 self.mapping[input_ids] = id
51 self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)) 54 self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
52 55
53 self.params[id] = tensor if tensor is not None else torch.zeros( 56 self.params[id] = tensor
54 self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
55 57
56 def unset(self, input_ids: torch.LongTensor): 58 def unset(self, input_ids: torch.LongTensor):
57 self.mapping[input_ids] = -1 59 self.mapping[input_ids] = -1
diff --git a/train_ti.py b/train_ti.py
index a9a2333..4366c9e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
353 parser.add_argument( 353 parser.add_argument(
354 "--adam_weight_decay", 354 "--adam_weight_decay",
355 type=float, 355 type=float,
356 default=1e-2, 356 default=0,
357 help="Weight decay to use." 357 help="Weight decay to use."
358 ) 358 )
359 parser.add_argument( 359 parser.add_argument(
@@ -451,10 +451,21 @@ def parse_args():
451 help="The weight of prior preservation loss." 451 help="The weight of prior preservation loss."
452 ) 452 )
453 parser.add_argument( 453 parser.add_argument(
454 "--emb_alpha", 454 "--use_emb_decay",
455 default=1.0, 455 action="store_true",
456 help="Whether to use embedding decay."
457 )
458 parser.add_argument(
459 "--emb_decay_target",
460 default=0.4,
461 type=float,
462 help="Embedding decay target."
463 )
464 parser.add_argument(
465 "--emb_decay",
466 default=1e+2,
456 type=float, 467 type=float,
457 help="Embedding alpha." 468 help="Embedding decay factor."
458 ) 469 )
459 parser.add_argument( 470 parser.add_argument(
460 "--noise_timesteps", 471 "--noise_timesteps",
@@ -600,7 +611,7 @@ def main():
600 save_args(output_dir, args) 611 save_args(output_dir, args)
601 612
602 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 613 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
603 args.pretrained_model_name_or_path, args.emb_alpha) 614 args.pretrained_model_name_or_path)
604 615
605 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 616 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
606 tokenizer.set_dropout(args.vector_dropout) 617 tokenizer.set_dropout(args.vector_dropout)
@@ -744,6 +755,9 @@ def main():
744 tokenizer=tokenizer, 755 tokenizer=tokenizer,
745 sample_scheduler=sample_scheduler, 756 sample_scheduler=sample_scheduler,
746 checkpoint_output_dir=checkpoint_output_dir, 757 checkpoint_output_dir=checkpoint_output_dir,
758 use_emb_decay=args.use_emb_decay,
759 emb_decay_target=args.emb_decay_target,
760 emb_decay=args.emb_decay,
747 use_ema=args.use_ema, 761 use_ema=args.use_ema,
748 ema_inv_gamma=args.ema_inv_gamma, 762 ema_inv_gamma=args.ema_inv_gamma,
749 ema_power=args.ema_power, 763 ema_power=args.ema_power,
diff --git a/training/functional.py b/training/functional.py
index 1d8e2ee..96ecbc1 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols):
73 return grid 73 return grid
74 74
75 75
76def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): 76def get_models(pretrained_model_name_or_path: str):
77 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 77 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
78 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 78 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
82 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 82 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
83 pretrained_model_name_or_path, subfolder='scheduler') 83 pretrained_model_name_or_path, subfolder='scheduler')
84 84
85 embeddings = patch_managed_embeddings(text_encoder, emb_alpha) 85 embeddings = patch_managed_embeddings(text_encoder)
86 86
87 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings 87 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
88 88
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 95128da..9df160a 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,6 +31,9 @@ def textual_inversion_strategy_callbacks(
31 seed: int, 31 seed: int,
32 placeholder_tokens: list[str], 32 placeholder_tokens: list[str],
33 placeholder_token_ids: list[list[int]], 33 placeholder_token_ids: list[list[int]],
34 use_emb_decay: bool = False,
35 emb_decay_target: float = 0.4,
36 emb_decay: float = 1e-2,
34 use_ema: bool = False, 37 use_ema: bool = False,
35 ema_inv_gamma: float = 1.0, 38 ema_inv_gamma: float = 1.0,
36 ema_power: int = 1, 39 ema_power: int = 1,
@@ -102,10 +105,26 @@ def textual_inversion_strategy_callbacks(
102 yield 105 yield
103 106
104 @torch.no_grad() 107 @torch.no_grad()
105 def on_after_optimize(zero_ids, lr: float): 108 def on_before_optimize(lr: float, epoch: int):
109 if use_emb_decay:
110 return torch.stack([
111 p
112 for p in text_encoder.text_model.embeddings.token_override_embedding.params
113 if p.grad is not None
114 ])
115
116 @torch.no_grad()
117 def on_after_optimize(w, lr: float):
106 if ema_embeddings is not None: 118 if ema_embeddings is not None:
107 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()) 119 ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
108 120
121 if use_emb_decay:
122 lambda_ = emb_decay * lr
123
124 if lambda_ != 0:
125 norm = w[:, :].norm(dim=-1, keepdim=True)
126 w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
127
109 def on_log(): 128 def on_log():
110 if ema_embeddings is not None: 129 if ema_embeddings is not None:
111 return {"ema_decay": ema_embeddings.decay} 130 return {"ema_decay": ema_embeddings.decay}
@@ -149,6 +168,7 @@ def textual_inversion_strategy_callbacks(
149 on_accum_model=on_accum_model, 168 on_accum_model=on_accum_model,
150 on_train=on_train, 169 on_train=on_train,
151 on_eval=on_eval, 170 on_eval=on_eval,
171 on_before_optimize=on_before_optimize,
152 on_after_optimize=on_after_optimize, 172 on_after_optimize=on_after_optimize,
153 on_log=on_log, 173 on_log=on_log,
154 on_checkpoint=on_checkpoint, 174 on_checkpoint=on_checkpoint,