-rw-r--r--   models/clip/embeddings.py   15
-rw-r--r--   models/sparse.py            14
-rw-r--r--   train_ti.py                 24
-rw-r--r--   training/functional.py       4
-rw-r--r--   training/strategy/ti.py     22
5 files changed, 57 insertions, 22 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index a356434..63a141f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
@@ -49,7 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype,
         )
-        self.alpha = alpha
 
     def resize(self, size: int):
         self.token_override_embedding.resize(size)
@@ -87,7 +86,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.token_embedding.weight.data[token_ids] = initializer
-        self.token_override_embedding.set(token_ids)
+        self.token_override_embedding.set(token_ids, initializer)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -101,8 +100,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
-            self.token_embedding.weight.data[input_ids] += self.alpha * embs
+            self.token_embedding.weight.data[input_ids] = embs
             self.token_override_embedding.unset(input_ids)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
@@ -111,7 +110,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         embs = self.token_embedding(input_ids)
         embs_override, mask = self.token_override_embedding(input_ids)
         if embs_override is not None:
-            embs[mask] += self.alpha * embs_override
+            embs[mask] = embs_override
 
         return embs
 
@@ -135,7 +134,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha)
+def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
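With the alpha term gone, an override, when present, now replaces the base embedding row outright in get_embed instead of being blended in as alpha * override. A minimal standalone sketch of that lookup behaviour (toy tensors, not the repo's ManagedCLIPTextEmbeddings):

import torch

vocab = torch.randn(10, 4)             # stand-in for token_embedding.weight
override = {3: torch.ones(4)}          # stand-in for the sparse override table

def get_embed(input_ids: torch.LongTensor) -> torch.Tensor:
    embs = vocab[input_ids].clone()
    for i, tok in enumerate(input_ids.tolist()):
        if tok in override:
            embs[i] = override[tok]    # full replacement, no alpha blend
    return embs

print(torch.equal(get_embed(torch.tensor([1, 3]))[1], torch.ones(4)))  # True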
diff --git a/models/sparse.py b/models/sparse.py
index 0b15454..8910316 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -13,10 +13,7 @@ class PseudoSparseEmbedding(nn.Module):
         self.params = nn.ParameterList()
         self.mapping = torch.zeros(0, device=device, dtype=torch.long)
 
-    def forward(self, input_ids: Optional[torch.LongTensor] = None):
-        if input_ids is None:
-            input_ids = torch.arange(self.mapping.shape[0])
-
+    def forward(self, input_ids: torch.LongTensor):
         ids = self.mapping[input_ids.to(self.mapping.device)]
         mask = ~(ids == -1)
 
@@ -43,6 +40,12 @@ class PseudoSparseEmbedding(nn.Module):
         else:
             return [self.set(id) for id in input_ids]
 
+        if tensor is None:
+            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
+
+        if tensor.shape[-1] != self.embedding_dim:
+            raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
+
         id = self.mapping[input_ids]
 
         if id == -1:
@@ -50,8 +53,7 @@ class PseudoSparseEmbedding(nn.Module):
             self.mapping[input_ids] = id
             self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
 
-        self.params[id] = tensor if tensor is not None else torch.zeros(
-            self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
+        self.params[id] = tensor
 
     def unset(self, input_ids: torch.LongTensor):
         self.mapping[input_ids] = -1
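The set() change moves the default initialization up front and adds a width check: with no tensor the slot is zero-filled, and a tensor whose last dimension differs from embedding_dim is rejected before any slot is touched. A condensed, standalone sketch of the mapping/slot bookkeeping this implies (toy class, not the repo's PseudoSparseEmbedding):

from typing import Optional

import torch

class ToySparseOverride:
    def __init__(self, embedding_dim: int, num_tokens: int):
        self.embedding_dim = embedding_dim
        self.mapping = torch.full((num_tokens,), -1, dtype=torch.long)  # token id -> slot, -1 = no override
        self.params: list[torch.Tensor] = []

    def set(self, token_id: int, tensor: Optional[torch.Tensor] = None):
        if tensor is None:
            tensor = torch.zeros(self.embedding_dim)                    # default: zero-initialized slot
        if tensor.shape[-1] != self.embedding_dim:
            raise ValueError(
                f"Expected tensor of shape [..., {self.embedding_dim}], "
                f"but got [..., {tensor.shape[-1]}]")
        slot = int(self.mapping[token_id])
        if slot == -1:                                                   # allocate a new slot on first use
            slot = len(self.params)
            self.mapping[token_id] = slot
            self.params.append(torch.zeros(self.embedding_dim))
        self.params[slot] = tensor

emb = ToySparseOverride(embedding_dim=4, num_tokens=10)
emb.set(3)                       # zero-initialized override
emb.set(3, torch.ones(4))        # later overwritten with an explicit initializer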
diff --git a/train_ti.py b/train_ti.py
index a9a2333..4366c9e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
        help="Weight decay to use."
     )
     parser.add_argument(
@@ -451,10 +451,21 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--emb_alpha",
-        default=1.0,
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e+2,
         type=float,
-        help="Embedding alpha."
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -600,7 +611,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_alpha)
+        args.pretrained_model_name_or_path)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -744,6 +755,9 @@ def main():
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
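Together with the plumbing in main(), the new behaviour is opt-in from the command line; an example invocation using only the flags introduced here, with "..." standing for the script's other (unchanged) arguments:

python train_ti.py ... --use_emb_decay --emb_decay_target 0.4 --emb_decay 1e+2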
diff --git a/training/functional.py b/training/functional.py
index 1d8e2ee..96ecbc1 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_alpha)
+    embeddings = patch_managed_embeddings(text_encoder)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 95128da..9df160a 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,6 +31,9 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -102,10 +105,26 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
+    def on_before_optimize(lr: float, epoch: int):
+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -149,6 +168,7 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
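The decay applied in on_after_optimize nudges each trained override row along its own direction so that its L2 norm drifts toward emb_decay_target, with an effective strength of lambda_ = emb_decay * lr (it scales with the learning rate and vanishes when lr is 0). A small numeric sketch of just that update, using illustrative values rather than anything from a real run:

import torch

emb_decay_target = 0.4
emb_decay = 1e+2
lr = 1e-4                                   # hypothetical learning rate

w = torch.randn(3, 768)                     # stand-in for the stacked override embeddings
lambda_ = emb_decay * lr                    # 1e-2 with these values

norm = w.norm(dim=-1, keepdim=True)
w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

print(norm.squeeze(-1))                     # norms before the step
print(w.norm(dim=-1))                       # slightly closer to emb_decay_target

Each step moves ||w|| by lambda_ * (emb_decay_target - ||w||), a proportional pull toward the target norm rather than plain L2 weight decay toward zero.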
