-rw-r--r--  models/clip/embeddings.py  |  42
-rw-r--r--  train_ti.py                |  52
-rw-r--r--  training/functional.py     |   2
-rw-r--r--  training/strategy/ti.py    | 100
4 files changed, 132 insertions, 64 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index c9c788c..1e21965 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,41 +31,15 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding
 
 
-class OverlayLinear(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
-
-        self.rank = rank
-        self.down = nn.Linear(in_features, rank, bias=False)
-        self.up = nn.Linear(rank, out_features, bias=False)
-        self.reset()
-
-    def reset(self):
-        nn.init.normal_(self.down.weight, std=1 / self.rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states)
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
+        self.alpha = alpha
 
-        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -75,9 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
-    def reset_overlay(self):
-        self.overlay.reset()
-
     def resize(self, size: int):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
@@ -125,9 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
-        self.overlay.reset()
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -135,11 +104,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
-        temp_embeds = self.temp_token_embedding(input_ids[mask])
-        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
+        embeds[mask] = self.temp_token_embedding(input_ids[mask])
 
         return embeds
 
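Note: with the overlay gone, get_embed reduces to a plain masked override — rows of the frozen token_embedding are swapped for rows of the trainable temp_token_embedding wherever the input id is a placeholder token. A self-contained sketch of that pattern (the sizes and trained_ids below are made-up stand-ins for the CLIP config values and temp_token_ids):

    import torch
    import torch.nn as nn

    num_tokens, dim = 49408, 768                 # hypothetical vocab size / width
    trained_ids = torch.tensor([49406, 49407])   # stand-in for temp_token_ids

    token_embedding = nn.Embedding(num_tokens, dim)       # frozen base table
    temp_token_embedding = nn.Embedding(num_tokens, dim)  # trainable copy

    input_ids = torch.tensor([[1, 49406, 2, 49407]])
    embeds = token_embedding(input_ids)
    # Override only the rows belonging to trained placeholder tokens.
    mask = torch.isin(input_ids, trained_ids)
    embeds[mask] = temp_token_embedding(input_ids[mask])
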
diff --git a/train_ti.py b/train_ti.py
index 26ac384..5482326 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -308,6 +307,26 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=4/5
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--optimizer",
         type=str,
         default="dadan",
@@ -334,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -432,6 +451,23 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e2,
+        type=float,
+        help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--noise_timesteps",
         type=int,
         default=1000,
@@ -696,6 +732,13 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -757,10 +800,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        itertools.chain(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            text_encoder.text_model.embeddings.overlay.parameters(),
-        ),
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
     )
 
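Note: a hypothetical invocation exercising the switches added above (only the new flags are shown; the model/data arguments the script also requires are omitted):

    python train_ti.py \
        --use_ema --ema_inv_gamma 1.0 --ema_power 0.8 --ema_max_decay 0.9999 \
        --use_emb_decay --emb_decay_target 0.4 --emb_decay 1e2
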
diff --git a/training/functional.py b/training/functional.py
index 7104a88..bd8cbad 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -524,7 +524,7 @@ def train_loop(
 
         lr = lr_scheduler.get_last_lr()[0]
         if torch.is_tensor(lr):
-            lr = lr.item
+            lr = lr.item()
 
         lrs.append(lr)
 
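Note: the lr.item → lr.item() change fixes a subtle bug — without the call parentheses, the bound method object itself was appended to lrs instead of a float. A quick illustration:

    import torch

    lr = torch.tensor(1e-4)
    broken = lr.item     # bound method Tensor.item, not a number
    fixed = lr.item()    # a plain Python float
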
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 1b5adab..677f5a3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,6 +1,6 @@
 from typing import Optional
 from functools import partial
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from pathlib import Path
 
 import torch
@@ -13,6 +13,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -31,6 +32,13 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: int = 1,
+    ema_max_decay: float = 0.9999,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
     sample_num_steps: int = 20,
@@ -63,8 +71,27 @@ def textual_inversion_strategy_callbacks(
         image_size=sample_image_size,
     )
 
+    if use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+        ema_embeddings.to(accelerator.device)
+    else:
+        ema_embeddings = None
+
+    def ema_context():
+        if ema_embeddings is not None:
+            return ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            )
+        else:
+            return nullcontext()
+
     def on_accum_model():
-        return text_encoder.text_model.embeddings
+        return text_encoder.text_model.embeddings.temp_token_embedding
 
     @contextmanager
     def on_train(epoch: int):
@@ -74,36 +101,68 @@ def textual_inversion_strategy_callbacks(
     @contextmanager
     def on_eval():
         tokenizer.eval()
-        yield
+
+        with ema_context():
+            yield
+
+    @torch.no_grad()
+    def on_before_optimize(lr: float, epoch: int):
+        if use_emb_decay:
+            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+            return torch.all(w.grad == 0, dim=1)
+
+    @torch.no_grad()
+    def on_after_optimize(zero_ids, lr: float):
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.ones(w.shape[0], dtype=torch.bool)
+                mask[zero_ids] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
+    def on_log():
+        if ema_embeddings is not None:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
-            text_encoder.text_model.embeddings.save_embed(
-                ids,
-                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
-            )
+        with ema_context():
+            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                )
 
     @torch.no_grad()
     def on_sample(step):
-        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+        with ema_context():
+            unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
-        orig_unet_dtype = unet_.dtype
-        orig_text_encoder_dtype = text_encoder_.dtype
+            orig_unet_dtype = unet_.dtype
+            orig_text_encoder_dtype = text_encoder_.dtype
 
-        unet_.to(dtype=weight_dtype)
-        text_encoder_.to(dtype=weight_dtype)
+            unet_.to(dtype=weight_dtype)
+            text_encoder_.to(dtype=weight_dtype)
 
-        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
-        unet_.to(dtype=orig_unet_dtype)
-        text_encoder_.to(dtype=orig_text_encoder_dtype)
+            unet_.to(dtype=orig_unet_dtype)
+            text_encoder_.to(dtype=orig_text_encoder_dtype)
 
-        del unet_
-        del text_encoder_
+            del unet_
+            del text_encoder_
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
@@ -112,6 +171,9 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
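
Note: the embedding decay added in on_after_optimize pulls the norm of every row that actually received a gradient toward emb_decay_target, scaled by lambda = emb_decay * lr: w <- w + lambda * (target - ||w||) * w / ||w||. One caveat worth flagging: indexing with a boolean mask (w[mask]) yields a copy, so an in-place add_ on it mutates only that temporary. A sketch of the same update that writes the result back (names are local stand-ins; zero_ids is the mask returned by on_before_optimize):

    import torch

    @torch.no_grad()
    def decay_embeddings(w: torch.Tensor, zero_ids: torch.Tensor, lr: float,
                         emb_decay: float = 1e2, emb_decay_target: float = 0.4):
        lambda_ = emb_decay * lr
        if lambda_ == 0:
            return

        mask = torch.ones(w.shape[0], dtype=torch.bool)
        mask[zero_ids] = False  # skip rows whose gradient was all zero

        norm = w[mask, :].norm(dim=-1, keepdim=True)
        # Assignment (index_put_) writes back into w; add_ on w[mask]
        # would only modify a temporary copy.
        w[mask] = w[mask] + (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)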