 models/clip/embeddings.py | 53
 train_ti.py               | 24
 training/strategy/ti.py   | 30
 3 files changed, 44 insertions(+), 63 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9abd1bb..88e0cc0 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,25 +31,47 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding
 
 
+class OverlayLinear(nn.Module):
+    def __init__(self, in_features, out_features, rank=4):
+        super().__init__()
+
+        if rank > min(in_features, out_features):
+            raise ValueError(f"Rank {rank} must be less than or equal to {min(in_features, out_features)}")
+
+        self.rank = rank
+        self.down = nn.Linear(in_features, rank, bias=False)
+        self.up = nn.Linear(rank, out_features, bias=False)
+        self.reset()
+
+    def reset(self):
+        nn.init.normal_(self.down.weight, std=1 / self.rank)
+        nn.init.zeros_(self.up.weight)
+
+    def forward(self, hidden_states):
+        orig_dtype = hidden_states.dtype
+        dtype = self.down.weight.dtype
+
+        down_hidden_states = self.down(hidden_states.to(dtype))
+        up_hidden_states = self.up(down_hidden_states)
+
+        return up_hidden_states.to(orig_dtype)
+
+
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
 
-        self.temp_token_embedding = nn.Embedding(
-            self.token_embedding.num_embeddings,
-            self.token_embedding.embedding_dim,
-            device=self.token_embedding.weight.device,
-            dtype=self.token_embedding.weight.dtype
-        )
-        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
+        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
+    def reset_overlay(self):
+        self.overlay.reset()
+
     def resize(self, size: int):
-        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -74,8 +96,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         initializer = self.get_embed(initializer)
 
         initializer = initializer.to(
-            device=self.temp_token_embedding.weight.device,
-            dtype=self.temp_token_embedding.weight.dtype,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype,
         )
 
         if initializer_noise != 0:
@@ -84,7 +106,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
@@ -95,7 +116,10 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.token_embedding.weight.data[self.temp_token_ids] += self.overlay(
+            self.token_embedding.weight.data[self.temp_token_ids]
+        )
+        self.overlay.reset()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -103,9 +127,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
+        embeds[mask] += self.overlay(embeds[mask])
 
         return embeds
 
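Net effect of the embeddings.py changes: instead of a full trainable copy of the token embedding table (`temp_token_embedding`), the managed embeddings now learn a LoRA-style low-rank residual that is added on top of the frozen base embeddings for the placeholder tokens, and `persist()` folds that residual back into the table before zeroing the adapter. A minimal, self-contained sketch of the composition; the sizes and token ids below are illustrative assumptions, not values taken from the repo:

```python
import torch
import torch.nn as nn

# Sketch assumes CLIP-like sizes: embedding_dim=768, vocab=49408, rank=128.
embedding_dim, rank = 768, 128
token_embedding = nn.Embedding(49408, embedding_dim)  # frozen base table

down = nn.Linear(embedding_dim, rank, bias=False)  # OverlayLinear.down
up = nn.Linear(rank, embedding_dim, bias=False)    # OverlayLinear.up
nn.init.normal_(down.weight, std=1 / rank)
nn.init.zeros_(up.weight)  # the residual starts at exactly zero

input_ids = torch.tensor([320, 49407])    # hypothetical prompt token ids
temp_token_ids = torch.tensor([49407])    # hypothetical placeholder id

embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids)
# Only placeholder rows receive the low-rank residual, mirroring get_embed().
embeds[mask] += up(down(embeds[mask]))
```

Because `up` is zero-initialized, the residual contributes nothing at step 0, so training starts from whatever initializer embeddings `add_embed` wrote into the base table.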
diff --git a/train_ti.py b/train_ti.py
index 5482326..0ce0056 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -451,23 +451,6 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
-    )
-    parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
-    )
-    parser.add_argument(
-        "--emb_decay",
-        default=1e2,
-        type=float,
-        help="Embedding decay factor."
-    )
-    parser.add_argument(
        "--noise_timesteps",
         type=int,
         default=1000,
@@ -732,9 +715,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
@@ -800,7 +780,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        text_encoder.text_model.embeddings.overlay.parameters(),
         lr=args.learning_rate,
     )
 
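With the overlay in place, the optimizer only sees the adapter's two small matrices rather than a trainable embedding table, which is also why the `--adam_weight_decay` default moves from 0 to 1e-2: plain weight decay on the adapter replaces the removed `--use_emb_decay`/`--emb_decay` norm regularizer. `create_optimizer` is the script's own factory; a rough standalone equivalent, assuming an AdamW-style optimizer (module shapes and learning rate here are illustrative):

```python
import torch
import torch.nn as nn

# Stand-in for text_encoder.text_model.embeddings.overlay (assumed shapes).
overlay = nn.Sequential(
    nn.Linear(768, 128, bias=False),  # down projection
    nn.Linear(128, 768, bias=False),  # up projection
)

# Roughly 2 * rank * embedding_dim weights are optimized now, instead of a
# full copy of the 49408-row embedding table.
optimizer = torch.optim.AdamW(
    overlay.parameters(),
    lr=1e-4,            # illustrative; the script passes args.learning_rate
    weight_decay=1e-2,  # matches the new --adam_weight_decay default
)
```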
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index b9a5547..19b8d25 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -32,9 +32,6 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
-    use_emb_decay: bool = False,
-    emb_decay_target: float = 0.4,
-    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -73,7 +70,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -85,13 +82,13 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+                text_encoder.text_model.embeddings.overlay.parameters()
             )
         else:
             return nullcontext()
 
     def on_accum_model():
-        return text_encoder.text_model.embeddings.temp_token_embedding
+        return text_encoder.text_model.embeddings.overlay
 
     @contextmanager
     def on_train(epoch: int):
@@ -106,27 +103,9 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
-        if use_emb_decay:
-            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-            return torch.all(w.grad == 0, dim=1)
-
-    @torch.no_grad()
     def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-        if use_emb_decay:
-            lambda_ = emb_decay * lr
-
-            if lambda_ != 0:
-                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
-
-                mask = torch.ones(w.shape[0], dtype=torch.bool)
-                mask[zero_ids] = False
-
-                norm = w[mask, :].norm(dim=-1, keepdim=True)
-                w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
 
     def on_log():
         if ema_embeddings is not None:
@@ -171,7 +150,6 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
-        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
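The strategy callbacks follow suit: EMA now tracks the overlay's parameters, and both `on_before_optimize` and the norm-targeting embedding decay in `on_after_optimize` are dropped, since weight decay on the adapter now covers that regularization. The repo's `EMAModel` API is not shown in this diff, so the sketch below uses a hand-rolled fixed-decay EMA purely to illustrate what stepping over `overlay.parameters()` amounts to (the inv_gamma/power warm-up schedule is omitted):

```python
import torch
import torch.nn as nn

class SimpleEMA:
    """Fixed-decay EMA over a parameter list (illustrative stand-in)."""
    def __init__(self, parameters, decay=0.9999):
        self.decay = decay
        self.shadow = [p.detach().clone() for p in parameters]

    @torch.no_grad()
    def step(self, parameters):
        # shadow <- decay * shadow + (1 - decay) * param
        for s, p in zip(self.shadow, parameters):
            s.mul_(self.decay).add_(p, alpha=1 - self.decay)

overlay = nn.Sequential(  # stand-in for the embeddings' overlay module
    nn.Linear(768, 128, bias=False),
    nn.Linear(128, 768, bias=False),
)
ema = SimpleEMA(overlay.parameters())
# After each optimizer step, mirror what on_after_optimize does:
ema.step(overlay.parameters())
```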
