summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py50
-rw-r--r--train_ti.py37
-rw-r--r--training/functional.py4
-rw-r--r--training/strategy/ti.py23
4 files changed, 46 insertions, 68 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 1e21965..d8343a0 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -12,7 +12,7 @@ from transformers.models.clip import CLIPTextConfig
12from transformers.models.clip.modeling_clip import CLIPTextEmbeddings 12from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
13 13
14 14
15def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding: 15def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding:
16 old_num_embeddings, old_embedding_dim = old_embedding.weight.shape 16 old_num_embeddings, old_embedding_dim = old_embedding.weight.shape
17 17
18 if old_num_embeddings == new_num_embeddings: 18 if old_num_embeddings == new_num_embeddings:
@@ -26,13 +26,16 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
26 device=old_embedding.weight.device, 26 device=old_embedding.weight.device,
27 dtype=old_embedding.weight.dtype 27 dtype=old_embedding.weight.dtype
28 ) 28 )
29 new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) 29 if initializer_factor is not None:
30 new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
31 else:
32 nn.init.zeros_(new_embedding.weight.data)
30 new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] 33 new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :]
31 return new_embedding 34 return new_embedding
32 35
33 36
34class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): 37class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
35 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4): 38 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0):
36 super().__init__(config) 39 super().__init__(config)
37 40
38 self.token_embedding = embeddings.token_embedding 41 self.token_embedding = embeddings.token_embedding
@@ -40,17 +43,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
40 self.initializer_factor = config.initializer_factor 43 self.initializer_factor = config.initializer_factor
41 self.alpha = alpha 44 self.alpha = alpha
42 45
43 self.temp_token_embedding = nn.Embedding( 46 self.temp_token_embedding = nn.ParameterList()
44 self.token_embedding.num_embeddings,
45 self.token_embedding.embedding_dim,
46 device=self.token_embedding.weight.device,
47 dtype=self.token_embedding.weight.dtype
48 )
49 self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
50 self.temp_token_ids = torch.tensor([], dtype=torch.long) 47 self.temp_token_ids = torch.tensor([], dtype=torch.long)
51 48
52 def resize(self, size: int): 49 def resize(self, size: int):
53 self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) 50 for _ in range(len(self.temp_token_embedding), size):
51 self.temp_token_embedding.append(torch.zeros(
52 self.token_embedding.embedding_dim,
53 device=self.token_embedding.weight.device,
54 dtype=self.token_embedding.weight.dtype,
55 ))
54 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) 56 self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
55 57
56 def add_embed( 58 def add_embed(
@@ -85,7 +87,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
85 token_ids = torch.tensor(token_ids, dtype=torch.long) 87 token_ids = torch.tensor(token_ids, dtype=torch.long)
86 88
87 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) 89 self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
88 self.temp_token_embedding.weight.data[token_ids] = initializer
89 self.token_embedding.weight.data[token_ids] = initializer 90 self.token_embedding.weight.data[token_ids] = initializer
90 91
91 def load_embed(self, input_ids: list[int], filename: Path): 92 def load_embed(self, input_ids: list[int], filename: Path):
@@ -96,16 +97,31 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
96 save_file({"embed": self.get_embed(input_ids)}, filename) 97 save_file({"embed": self.get_embed(input_ids)}, filename)
97 98
98 def persist(self): 99 def persist(self):
99 self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] 100 for id, emb in zip(self.temp_token_ids, self.temp_token_embedding):
101 self.token_embedding.weight.data[id] += self.alpha * emb
102 nn.init.zeros_(emb)
100 self.temp_token_ids = torch.tensor([], dtype=torch.long) 103 self.temp_token_ids = torch.tensor([], dtype=torch.long)
101 104
102 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 105 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
103 if isinstance(input_ids, list): 106 if isinstance(input_ids, list):
104 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 107 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
105 108
109 all_temp_token_ids = self.temp_token_ids.to(input_ids.device)
110
106 embeds = self.token_embedding(input_ids) 111 embeds = self.token_embedding(input_ids)
107 mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) 112 mask = torch.isin(input_ids, all_temp_token_ids)
108 embeds[mask] = self.temp_token_embedding(input_ids[mask]) 113 temp_token_ids = input_ids[mask]
114
115 temp_token_ids = temp_token_ids.unsqueeze(1)
116 all_temp_token_ids = all_temp_token_ids.unsqueeze(0)
117 temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze()
118
119 if len(temp_token_ids):
120 embeds_override = torch.stack([
121 self.temp_token_embedding[id]
122 for id in temp_token_ids
123 ])
124 embeds[mask] += self.alpha * embeds_override
109 125
110 return embeds 126 return embeds
111 127
@@ -129,7 +145,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
129 return embeddings 145 return embeddings
130 146
131 147
132def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: 148def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings:
133 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) 149 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha)
134 text_encoder.text_model.embeddings = text_embeddings 150 text_encoder.text_model.embeddings = text_embeddings
135 return text_embeddings 151 return text_embeddings
diff --git a/train_ti.py b/train_ti.py
index 8dde1ba..0ad7574 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
353 parser.add_argument( 353 parser.add_argument(
354 "--adam_weight_decay", 354 "--adam_weight_decay",
355 type=float, 355 type=float,
356 default=0, 356 default=1e-2,
357 help="Weight decay to use." 357 help="Weight decay to use."
358 ) 358 )
359 parser.add_argument( 359 parser.add_argument(
@@ -451,21 +451,10 @@ def parse_args():
451 help="The weight of prior preservation loss." 451 help="The weight of prior preservation loss."
452 ) 452 )
453 parser.add_argument( 453 parser.add_argument(
454 "--use_emb_decay", 454 "--emb_alpha",
455 action="store_true", 455 default=1.0,
456 help="Whether to use embedding decay."
457 )
458 parser.add_argument(
459 "--emb_decay_target",
460 default=0.4,
461 type=float,
462 help="Embedding decay target."
463 )
464 parser.add_argument(
465 "--emb_decay",
466 default=1e2,
467 type=float, 456 type=float,
468 help="Embedding decay factor." 457 help="Embedding alpha."
469 ) 458 )
470 parser.add_argument( 459 parser.add_argument(
471 "--noise_timesteps", 460 "--noise_timesteps",
@@ -567,16 +556,16 @@ def parse_args():
567 raise ValueError("You must specify --output_dir") 556 raise ValueError("You must specify --output_dir")
568 557
569 if args.adam_beta1 is None: 558 if args.adam_beta1 is None:
570 if args.optimizer in ('adam', 'adam8bit'): 559 if args.optimizer == 'lion':
571 args.adam_beta1 = 0.9
572 elif args.optimizer == 'lion':
573 args.adam_beta1 = 0.95 560 args.adam_beta1 = 0.95
561 else:
562 args.adam_beta1 = 0.9
574 563
575 if args.adam_beta2 is None: 564 if args.adam_beta2 is None:
576 if args.optimizer in ('adam', 'adam8bit'): 565 if args.optimizer == 'lion':
577 args.adam_beta2 = 0.999
578 elif args.optimizer == 'lion':
579 args.adam_beta2 = 0.98 566 args.adam_beta2 = 0.98
567 else:
568 args.adam_beta2 = 0.999
580 569
581 return args 570 return args
582 571
@@ -611,7 +600,7 @@ def main():
611 save_args(output_dir, args) 600 save_args(output_dir, args)
612 601
613 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 602 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
614 args.pretrained_model_name_or_path) 603 args.pretrained_model_name_or_path, args.emb_alpha)
615 604
616 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 605 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
617 tokenizer.set_dropout(args.vector_dropout) 606 tokenizer.set_dropout(args.vector_dropout)
@@ -755,10 +744,6 @@ def main():
755 tokenizer=tokenizer, 744 tokenizer=tokenizer,
756 sample_scheduler=sample_scheduler, 745 sample_scheduler=sample_scheduler,
757 checkpoint_output_dir=checkpoint_output_dir, 746 checkpoint_output_dir=checkpoint_output_dir,
758 gradient_checkpointing=args.gradient_checkpointing,
759 use_emb_decay=args.use_emb_decay,
760 emb_decay_target=args.emb_decay_target,
761 emb_decay=args.emb_decay,
762 use_ema=args.use_ema, 747 use_ema=args.use_ema,
763 ema_inv_gamma=args.ema_inv_gamma, 748 ema_inv_gamma=args.ema_inv_gamma,
764 ema_power=args.ema_power, 749 ema_power=args.ema_power,
diff --git a/training/functional.py b/training/functional.py
index 96ecbc1..1d8e2ee 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols):
73 return grid 73 return grid
74 74
75 75
76def get_models(pretrained_model_name_or_path: str): 76def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
77 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 77 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
78 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 78 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str):
82 sample_scheduler = UniPCMultistepScheduler.from_pretrained( 82 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
83 pretrained_model_name_or_path, subfolder='scheduler') 83 pretrained_model_name_or_path, subfolder='scheduler')
84 84
85 embeddings = patch_managed_embeddings(text_encoder) 85 embeddings = patch_managed_embeddings(text_encoder, emb_alpha)
86 86
87 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings 87 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
88 88
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index c7520ed..16baa34 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,10 +31,6 @@ def textual_inversion_strategy_callbacks(
31 seed: int, 31 seed: int,
32 placeholder_tokens: list[str], 32 placeholder_tokens: list[str],
33 placeholder_token_ids: list[list[int]], 33 placeholder_token_ids: list[list[int]],
34 gradient_checkpointing: bool = False,
35 use_emb_decay: bool = False,
36 emb_decay_target: float = 0.4,
37 emb_decay: float = 1e-2,
38 use_ema: bool = False, 34 use_ema: bool = False,
39 ema_inv_gamma: float = 1.0, 35 ema_inv_gamma: float = 1.0,
40 ema_power: int = 1, 36 ema_power: int = 1,
@@ -106,28 +102,10 @@ def textual_inversion_strategy_callbacks(
106 yield 102 yield
107 103
108 @torch.no_grad() 104 @torch.no_grad()
109 def on_before_optimize(lr: float, epoch: int):
110 if use_emb_decay:
111 w = text_encoder.text_model.embeddings.temp_token_embedding.weight
112 return torch.all(w.grad == 0, dim=1)
113
114 @torch.no_grad()
115 def on_after_optimize(zero_ids, lr: float): 105 def on_after_optimize(zero_ids, lr: float):
116 if ema_embeddings is not None: 106 if ema_embeddings is not None:
117 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) 107 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
118 108
119 if use_emb_decay:
120 lambda_ = emb_decay * lr
121
122 if lambda_ != 0:
123 w = text_encoder.text_model.embeddings.temp_token_embedding.weight
124
125 mask = torch.ones(w.shape[0], dtype=torch.bool)
126 mask[zero_ids] = False
127
128 norm = w[mask, :].norm(dim=-1, keepdim=True)
129 w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
130
131 def on_log(): 109 def on_log():
132 if ema_embeddings is not None: 110 if ema_embeddings is not None:
133 return {"ema_decay": ema_embeddings.decay} 111 return {"ema_decay": ema_embeddings.decay}
@@ -171,7 +149,6 @@ def textual_inversion_strategy_callbacks(
171 on_accum_model=on_accum_model, 149 on_accum_model=on_accum_model,
172 on_train=on_train, 150 on_train=on_train,
173 on_eval=on_eval, 151 on_eval=on_eval,
174 on_before_optimize=on_before_optimize,
175 on_after_optimize=on_after_optimize, 152 on_after_optimize=on_after_optimize,
176 on_log=on_log, 153 on_log=on_log,
177 on_checkpoint=on_checkpoint, 154 on_checkpoint=on_checkpoint,