-rw-r--r--  models/clip/embeddings.py |  42
-rw-r--r--  train_ti.py               |  52
-rw-r--r--  training/functional.py    |   2
-rw-r--r--  training/strategy/ti.py   | 100
4 files changed, 132 insertions(+), 64 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index c9c788c..1e21965 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -31,41 +31,15 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
     return new_embedding
 
 
-class OverlayLinear(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(f"Rank {rank} must be less or equal than {min(in_features, out_features)}")
-
-        self.rank = rank
-        self.down = nn.Linear(in_features, rank, bias=False)
-        self.up = nn.Linear(rank, out_features, bias=False)
-        self.reset()
-
-    def reset(self):
-        nn.init.normal_(self.down.weight, std=1 / self.rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states)
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, rank: int = 128):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0, rank: int = 4):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
         self.position_embedding = embeddings.position_embedding
         self.initializer_factor = config.initializer_factor
+        self.alpha = alpha
 
-        self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
         self.temp_token_embedding = nn.Embedding(
             self.token_embedding.num_embeddings,
             self.token_embedding.embedding_dim,
@@ -75,9 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
-    def reset_overlay(self):
-        self.overlay.reset()
-
     def resize(self, size: int):
         self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
@@ -125,9 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
-        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
-        self.overlay.reset()
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
@@ -135,11 +104,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
-
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
-        temp_embeds = self.temp_token_embedding(input_ids[mask])
-        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
+        embeds[mask] = self.temp_token_embedding(input_ids[mask])
 
         return embeds
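With the OverlayLinear residual gone, the temp embedding rows are trained directly and persist() reduces to a row copy. A minimal standalone sketch of the resulting lookup/persist pattern (toy sizes and ids; names mirror the diff, but this is illustration, not code from the repo):

import torch
import torch.nn as nn

# Toy sketch: placeholder tokens train in a scratch table, get_embed()
# overlays them onto the frozen table, persist() copies the rows back.
vocab_size, dim = 10, 4
token_embedding = nn.Embedding(vocab_size, dim)
temp_token_embedding = nn.Embedding(vocab_size, dim)
temp_token_embedding.weight.data = token_embedding.weight.data.clone().detach()
temp_token_ids = torch.tensor([7, 8], dtype=torch.long)  # placeholder tokens

# get_embed(): base lookup, then overwrite the placeholder positions
input_ids = torch.tensor([1, 7, 3, 8])
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids)
embeds[mask] = temp_token_embedding(input_ids[mask])

# persist(): copy the trained rows into the permanent table
token_embedding.weight.data[temp_token_ids] = temp_token_embedding.weight.data[temp_token_ids]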
diff --git a/train_ti.py b/train_ti.py
index 26ac384..5482326 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -308,6 +307,26 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=4/5
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--optimizer",
         type=str,
         default="dadan",
@@ -334,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -432,6 +451,23 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e2,
+        type=float,
+        help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--noise_timesteps",
         type=int,
         default=1000,
@@ -696,6 +732,13 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -757,10 +800,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        itertools.chain(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            text_encoder.text_model.embeddings.overlay.parameters(),
-        ),
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
     )
 
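The EMA defaults added here (--ema_inv_gamma 1.0, --ema_power 4/5, --ema_max_decay 0.9999) suggest a warmup-style decay schedule. A small sketch, assuming training.util.EMAModel follows the common diffusers-style formula decay = 1 - (1 + step/inv_gamma)^(-power) clamped to max_decay; this helper is illustrative, not the repo's implementation:

# Assumed decay schedule for the new EMA flags (illustration only)
def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 4 / 5,
              max_decay: float = 0.9999) -> float:
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max(value, 0.0), max_decay)

# Ramps quickly, then saturates near the cap:
# step 1 -> ~0.43, step 10 -> ~0.85, step 1000 -> ~0.996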
diff --git a/training/functional.py b/training/functional.py
index 7104a88..bd8cbad 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -524,7 +524,7 @@ def train_loop(
 
         lr = lr_scheduler.get_last_lr()[0]
         if torch.is_tensor(lr):
-            lr = lr.item
+            lr = lr.item()
 
         lrs.append(lr)
 
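The training/functional.py hunk is a one-character bug fix: without the call parentheses, lr.item binds the method object itself, so lrs would accumulate bound methods instead of floats. A quick illustration:

import torch

lr = torch.tensor(5e-5)
print(lr.item)    # <built-in method item of Tensor object at 0x...> (the old bug)
print(lr.item())  # 4.999...e-05, a plain Python float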
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 1b5adab..677f5a3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,6 +1,6 @@
 from typing import Optional
 from functools import partial
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from pathlib import Path
 
 import torch
@@ -13,6 +13,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -31,6 +32,13 @@
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
    gradient_checkpointing: bool = False,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: int = 1,
+    ema_max_decay: float = 0.9999,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
     sample_num_steps: int = 20,
@@ -63,8 +71,27 @@
         image_size=sample_image_size,
     )
 
+    if use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+        ema_embeddings.to(accelerator.device)
+    else:
+        ema_embeddings = None
+
+    def ema_context():
+        if ema_embeddings is not None:
+            return ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            )
+        else:
+            return nullcontext()
+
     def on_accum_model():
-        return text_encoder.text_model.embeddings
+        return text_encoder.text_model.embeddings.temp_token_embedding
 
     @contextmanager
     def on_train(epoch: int):
@@ -74,36 +101,68 @@
     @contextmanager
     def on_eval():
         tokenizer.eval()
-        yield
+
+        with ema_context():
+            yield
+
+    @torch.no_grad()
+    def on_before_optimize(lr: float, epoch: int):
+        if use_emb_decay:
+            w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+            return torch.all(w.grad == 0, dim=1)
+
+    @torch.no_grad()
+    def on_after_optimize(zero_ids, lr: float):
+        if ema_embeddings is not None:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                w = text_encoder.text_model.embeddings.temp_token_embedding.weight
+
+                mask = torch.ones(w.shape[0], dtype=torch.bool, device=w.device)
+                mask[zero_ids] = False
+
+                norm = w[mask, :].norm(dim=-1, keepdim=True)
+                w[mask] += (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+
+    def on_log():
+        if ema_embeddings is not None:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
-            text_encoder.text_model.embeddings.save_embed(
-                ids,
-                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
-            )
+        with ema_context():
+            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                )
 
     @torch.no_grad()
     def on_sample(step):
-        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
-
-        orig_unet_dtype = unet_.dtype
-        orig_text_encoder_dtype = text_encoder_.dtype
-
-        unet_.to(dtype=weight_dtype)
-        text_encoder_.to(dtype=weight_dtype)
-
-        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
-
-        unet_.to(dtype=orig_unet_dtype)
-        text_encoder_.to(dtype=orig_text_encoder_dtype)
-
-        del unet_
-        del text_encoder_
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        with ema_context():
+            unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+
+            orig_unet_dtype = unet_.dtype
+            orig_text_encoder_dtype = text_encoder_.dtype
+
+            unet_.to(dtype=weight_dtype)
+            text_encoder_.to(dtype=weight_dtype)
+
+            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+
+            unet_.to(dtype=orig_unet_dtype)
+            text_encoder_.to(dtype=orig_text_encoder_dtype)
+
+            del unet_
+            del text_encoder_
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
@@ -112,6 +171,9 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
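The new embedding decay in on_after_optimize pulls each updated row's L2 norm toward emb_decay_target with step size lambda_ = emb_decay * lr, skipping rows whose gradient was all zero this step (the mask returned by on_before_optimize). A standalone sketch of that update with toy values; note the `w[mask] += ...` form, since an in-place `.add_` on a boolean-masked copy would not write back to w:

import torch

# Toy stand-ins for temp_token_embedding.weight and the zero-grad row mask
w = torch.randn(5, 4)
zero_grad_rows = torch.tensor([False, True, False, False, True])
emb_decay, lr, emb_decay_target = 1e2, 1e-4, 0.4

lambda_ = emb_decay * lr
mask = ~zero_grad_rows                      # only decay rows that actually trained
norm = w[mask].norm(dim=-1, keepdim=True)   # per-row L2 norm
# Move each row's norm a small step toward the target, along w / ||w||
w[mask] += (w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)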
