diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-03 18:52:30 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-03 18:52:30 +0200 |
| commit | e68cb3542e08c9f22ce8a94fd88bebe0c121ca17 (patch) | |
| tree | 87fbb9d92233aa1bb7342e31aec64d6d375f41e1 | |
| parent | TI: No tag dropout by default (diff) | |
| download | textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.tar.gz textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.tar.bz2 textual-inversion-diff-e68cb3542e08c9f22ce8a94fd88bebe0c121ca17.zip | |
TI: Delta learning
| -rw-r--r-- | models/clip/embeddings.py | 50 | ||||
| -rw-r--r-- | train_ti.py | 37 | ||||
| -rw-r--r-- | training/functional.py | 4 | ||||
| -rw-r--r-- | training/strategy/ti.py | 23 |
4 files changed, 46 insertions, 68 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 1e21965..d8343a0 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -12,7 +12,7 @@ from transformers.models.clip import CLIPTextConfig | |||
| 12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings | 12 | from transformers.models.clip.modeling_clip import CLIPTextEmbeddings |
| 13 | 13 | ||
| 14 | 14 | ||
| 15 | def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding: | 15 | def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: Optional[float] = None) -> nn.Embedding: |
| 16 | old_num_embeddings, old_embedding_dim = old_embedding.weight.shape | 16 | old_num_embeddings, old_embedding_dim = old_embedding.weight.shape |
| 17 | 17 | ||
| 18 | if old_num_embeddings == new_num_embeddings: | 18 | if old_num_embeddings == new_num_embeddings: |
| @@ -26,13 +26,16 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi | |||
| 26 | device=old_embedding.weight.device, | 26 | device=old_embedding.weight.device, |
| 27 | dtype=old_embedding.weight.dtype | 27 | dtype=old_embedding.weight.dtype |
| 28 | ) | 28 | ) |
| 29 | new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | 29 | if initializer_factor is not None: |
| 30 | new_embedding.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | ||
| 31 | else: | ||
| 32 | nn.init.zeros_(new_embedding.weight.data) | ||
| 30 | new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] | 33 | new_embedding.weight.data[:n, :] = old_embedding.weight.data[:n, :] |
| 31 | return new_embedding | 34 | return new_embedding |
| 32 | 35 | ||
| 33 | 36 | ||
| 34 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 37 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
| 35 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4): | 38 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0): |
| 36 | super().__init__(config) | 39 | super().__init__(config) |
| 37 | 40 | ||
| 38 | self.token_embedding = embeddings.token_embedding | 41 | self.token_embedding = embeddings.token_embedding |
| @@ -40,17 +43,16 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 40 | self.initializer_factor = config.initializer_factor | 43 | self.initializer_factor = config.initializer_factor |
| 41 | self.alpha = alpha | 44 | self.alpha = alpha |
| 42 | 45 | ||
| 43 | self.temp_token_embedding = nn.Embedding( | 46 | self.temp_token_embedding = nn.ParameterList() |
| 44 | self.token_embedding.num_embeddings, | ||
| 45 | self.token_embedding.embedding_dim, | ||
| 46 | device=self.token_embedding.weight.device, | ||
| 47 | dtype=self.token_embedding.weight.dtype | ||
| 48 | ) | ||
| 49 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
| 50 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 47 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 51 | 48 | ||
| 52 | def resize(self, size: int): | 49 | def resize(self, size: int): |
| 53 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | 50 | for _ in range(len(self.temp_token_embedding), size): |
| 51 | self.temp_token_embedding.append(torch.zeros( | ||
| 52 | self.token_embedding.embedding_dim, | ||
| 53 | device=self.token_embedding.weight.device, | ||
| 54 | dtype=self.token_embedding.weight.dtype, | ||
| 55 | )) | ||
| 54 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 56 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
| 55 | 57 | ||
| 56 | def add_embed( | 58 | def add_embed( |
| @@ -85,7 +87,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 85 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 87 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
| 86 | 88 | ||
| 87 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 89 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 88 | self.temp_token_embedding.weight.data[token_ids] = initializer | ||
| 89 | self.token_embedding.weight.data[token_ids] = initializer | 90 | self.token_embedding.weight.data[token_ids] = initializer |
| 90 | 91 | ||
| 91 | def load_embed(self, input_ids: list[int], filename: Path): | 92 | def load_embed(self, input_ids: list[int], filename: Path): |
| @@ -96,16 +97,31 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 96 | save_file({"embed": self.get_embed(input_ids)}, filename) | 97 | save_file({"embed": self.get_embed(input_ids)}, filename) |
| 97 | 98 | ||
| 98 | def persist(self): | 99 | def persist(self): |
| 99 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 100 | for id, emb in zip(self.temp_token_ids, self.temp_token_embedding): |
| 101 | self.token_embedding.weight.data[id] += self.alpha * emb | ||
| 102 | nn.init.zeros_(emb) | ||
| 100 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 103 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
| 101 | 104 | ||
| 102 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 105 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 103 | if isinstance(input_ids, list): | 106 | if isinstance(input_ids, list): |
| 104 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 107 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
| 105 | 108 | ||
| 109 | all_temp_token_ids = self.temp_token_ids.to(input_ids.device) | ||
| 110 | |||
| 106 | embeds = self.token_embedding(input_ids) | 111 | embeds = self.token_embedding(input_ids) |
| 107 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | 112 | mask = torch.isin(input_ids, all_temp_token_ids) |
| 108 | embeds[mask] = self.temp_token_embedding(input_ids[mask]) | 113 | temp_token_ids = input_ids[mask] |
| 114 | |||
| 115 | temp_token_ids = temp_token_ids.unsqueeze(1) | ||
| 116 | all_temp_token_ids = all_temp_token_ids.unsqueeze(0) | ||
| 117 | temp_token_ids = torch.nonzero(temp_token_ids == all_temp_token_ids)[:, 1].squeeze() | ||
| 118 | |||
| 119 | if len(temp_token_ids): | ||
| 120 | embeds_override = torch.stack([ | ||
| 121 | self.temp_token_embedding[id] | ||
| 122 | for id in temp_token_ids | ||
| 123 | ]) | ||
| 124 | embeds[mask] += self.alpha * embeds_override | ||
| 109 | 125 | ||
| 110 | return embeds | 126 | return embeds |
| 111 | 127 | ||
| @@ -129,7 +145,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 129 | return embeddings | 145 | return embeddings |
| 130 | 146 | ||
| 131 | 147 | ||
| 132 | def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: | 148 | def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings: |
| 133 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) | 149 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha) |
| 134 | text_encoder.text_model.embeddings = text_embeddings | 150 | text_encoder.text_model.embeddings = text_embeddings |
| 135 | return text_embeddings | 151 | return text_embeddings |
diff --git a/train_ti.py b/train_ti.py index 8dde1ba..0ad7574 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -353,7 +353,7 @@ def parse_args(): | |||
| 353 | parser.add_argument( | 353 | parser.add_argument( |
| 354 | "--adam_weight_decay", | 354 | "--adam_weight_decay", |
| 355 | type=float, | 355 | type=float, |
| 356 | default=0, | 356 | default=1e-2, |
| 357 | help="Weight decay to use." | 357 | help="Weight decay to use." |
| 358 | ) | 358 | ) |
| 359 | parser.add_argument( | 359 | parser.add_argument( |
| @@ -451,21 +451,10 @@ def parse_args(): | |||
| 451 | help="The weight of prior preservation loss." | 451 | help="The weight of prior preservation loss." |
| 452 | ) | 452 | ) |
| 453 | parser.add_argument( | 453 | parser.add_argument( |
| 454 | "--use_emb_decay", | 454 | "--emb_alpha", |
| 455 | action="store_true", | 455 | default=1.0, |
| 456 | help="Whether to use embedding decay." | ||
| 457 | ) | ||
| 458 | parser.add_argument( | ||
| 459 | "--emb_decay_target", | ||
| 460 | default=0.4, | ||
| 461 | type=float, | ||
| 462 | help="Embedding decay target." | ||
| 463 | ) | ||
| 464 | parser.add_argument( | ||
| 465 | "--emb_decay", | ||
| 466 | default=1e2, | ||
| 467 | type=float, | 456 | type=float, |
| 468 | help="Embedding decay factor." | 457 | help="Embedding alpha." |
| 469 | ) | 458 | ) |
| 470 | parser.add_argument( | 459 | parser.add_argument( |
| 471 | "--noise_timesteps", | 460 | "--noise_timesteps", |
| @@ -567,16 +556,16 @@ def parse_args(): | |||
| 567 | raise ValueError("You must specify --output_dir") | 556 | raise ValueError("You must specify --output_dir") |
| 568 | 557 | ||
| 569 | if args.adam_beta1 is None: | 558 | if args.adam_beta1 is None: |
| 570 | if args.optimizer in ('adam', 'adam8bit'): | 559 | if args.optimizer == 'lion': |
| 571 | args.adam_beta1 = 0.9 | ||
| 572 | elif args.optimizer == 'lion': | ||
| 573 | args.adam_beta1 = 0.95 | 560 | args.adam_beta1 = 0.95 |
| 561 | else: | ||
| 562 | args.adam_beta1 = 0.9 | ||
| 574 | 563 | ||
| 575 | if args.adam_beta2 is None: | 564 | if args.adam_beta2 is None: |
| 576 | if args.optimizer in ('adam', 'adam8bit'): | 565 | if args.optimizer == 'lion': |
| 577 | args.adam_beta2 = 0.999 | ||
| 578 | elif args.optimizer == 'lion': | ||
| 579 | args.adam_beta2 = 0.98 | 566 | args.adam_beta2 = 0.98 |
| 567 | else: | ||
| 568 | args.adam_beta2 = 0.999 | ||
| 580 | 569 | ||
| 581 | return args | 570 | return args |
| 582 | 571 | ||
| @@ -611,7 +600,7 @@ def main(): | |||
| 611 | save_args(output_dir, args) | 600 | save_args(output_dir, args) |
| 612 | 601 | ||
| 613 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 602 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| 614 | args.pretrained_model_name_or_path) | 603 | args.pretrained_model_name_or_path, args.emb_alpha) |
| 615 | 604 | ||
| 616 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 605 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
| 617 | tokenizer.set_dropout(args.vector_dropout) | 606 | tokenizer.set_dropout(args.vector_dropout) |
| @@ -755,10 +744,6 @@ def main(): | |||
| 755 | tokenizer=tokenizer, | 744 | tokenizer=tokenizer, |
| 756 | sample_scheduler=sample_scheduler, | 745 | sample_scheduler=sample_scheduler, |
| 757 | checkpoint_output_dir=checkpoint_output_dir, | 746 | checkpoint_output_dir=checkpoint_output_dir, |
| 758 | gradient_checkpointing=args.gradient_checkpointing, | ||
| 759 | use_emb_decay=args.use_emb_decay, | ||
| 760 | emb_decay_target=args.emb_decay_target, | ||
| 761 | emb_decay=args.emb_decay, | ||
| 762 | use_ema=args.use_ema, | 747 | use_ema=args.use_ema, |
| 763 | ema_inv_gamma=args.ema_inv_gamma, | 748 | ema_inv_gamma=args.ema_inv_gamma, |
| 764 | ema_power=args.ema_power, | 749 | ema_power=args.ema_power, |
diff --git a/training/functional.py b/training/functional.py index 96ecbc1..1d8e2ee 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -73,7 +73,7 @@ def make_grid(images, rows, cols): | |||
| 73 | return grid | 73 | return grid |
| 74 | 74 | ||
| 75 | 75 | ||
| 76 | def get_models(pretrained_model_name_or_path: str): | 76 | def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0): |
| 77 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 77 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
| 78 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 78 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
| 79 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 79 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
| @@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str): | |||
| 82 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 82 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
| 83 | pretrained_model_name_or_path, subfolder='scheduler') | 83 | pretrained_model_name_or_path, subfolder='scheduler') |
| 84 | 84 | ||
| 85 | embeddings = patch_managed_embeddings(text_encoder) | 85 | embeddings = patch_managed_embeddings(text_encoder, emb_alpha) |
| 86 | 86 | ||
| 87 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | 87 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings |
| 88 | 88 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index c7520ed..16baa34 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -31,10 +31,6 @@ def textual_inversion_strategy_callbacks( | |||
| 31 | seed: int, | 31 | seed: int, |
| 32 | placeholder_tokens: list[str], | 32 | placeholder_tokens: list[str], |
| 33 | placeholder_token_ids: list[list[int]], | 33 | placeholder_token_ids: list[list[int]], |
| 34 | gradient_checkpointing: bool = False, | ||
| 35 | use_emb_decay: bool = False, | ||
| 36 | emb_decay_target: float = 0.4, | ||
| 37 | emb_decay: float = 1e-2, | ||
| 38 | use_ema: bool = False, | 34 | use_ema: bool = False, |
| 39 | ema_inv_gamma: float = 1.0, | 35 | ema_inv_gamma: float = 1.0, |
| 40 | ema_power: int = 1, | 36 | ema_power: int = 1, |
| @@ -106,28 +102,10 @@ def textual_inversion_strategy_callbacks( | |||
| 106 | yield | 102 | yield |
| 107 | 103 | ||
| 108 | @torch.no_grad() | 104 | @torch.no_grad() |
| 109 | def on_before_optimize(lr: float, epoch: int): | ||
| 110 | if use_emb_decay: | ||
| 111 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | ||
| 112 | return torch.all(w.grad == 0, dim=1) | ||
| 113 | |||
| 114 | @torch.no_grad() | ||
| 115 | def on_after_optimize(zero_ids, lr: float): | 105 | def on_after_optimize(zero_ids, lr: float): |
| 116 | if ema_embeddings is not None: | 106 | if ema_embeddings is not None: |
| 117 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 107 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |
| 118 | 108 | ||
| 119 | if use_emb_decay: | ||
| 120 | lambda_ = emb_decay * lr | ||
| 121 | |||
| 122 | if lambda_ != 0: | ||
| 123 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | ||
| 124 | |||
| 125 | mask = torch.ones(w.shape[0], dtype=torch.bool) | ||
| 126 | mask[zero_ids] = False | ||
| 127 | |||
| 128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | ||
| 129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
| 130 | |||
| 131 | def on_log(): | 109 | def on_log(): |
| 132 | if ema_embeddings is not None: | 110 | if ema_embeddings is not None: |
| 133 | return {"ema_decay": ema_embeddings.decay} | 111 | return {"ema_decay": ema_embeddings.decay} |
| @@ -171,7 +149,6 @@ def textual_inversion_strategy_callbacks( | |||
| 171 | on_accum_model=on_accum_model, | 149 | on_accum_model=on_accum_model, |
| 172 | on_train=on_train, | 150 | on_train=on_train, |
| 173 | on_eval=on_eval, | 151 | on_eval=on_eval, |
| 174 | on_before_optimize=on_before_optimize, | ||
| 175 | on_after_optimize=on_after_optimize, | 152 | on_after_optimize=on_after_optimize, |
| 176 | on_log=on_log, | 153 | on_log=on_log, |
| 177 | on_checkpoint=on_checkpoint, | 154 | on_checkpoint=on_checkpoint, |
