-rw-r--r--   models/clip/embeddings.py  | 19
-rw-r--r--   train_ti.py                | 30
-rw-r--r--   training/strategy/ti.py    | 76

3 files changed, 38 insertions, 87 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 88e0cc0..c9c788c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -66,12 +66,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.initializer_factor = config.initializer_factor
 
         self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
+        self.temp_token_embedding = nn.Embedding(
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype
+        )
+        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def reset_overlay(self):
         self.overlay.reset()
 
     def resize(self, size: int):
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -106,6 +114,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
@@ -116,9 +125,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] += self.overlay(
-            self.token_embedding.weight.data[self.temp_token_ids]
-        )
+        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
         self.overlay.reset()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
@@ -127,8 +135,11 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
+
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-        embeds[mask] += self.overlay(embeds[mask])
+
+        temp_embeds = self.temp_token_embedding(input_ids[mask])
+        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
 
         return embeds
 
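Net effect of the embeddings.py changes: placeholder tokens are no longer trained through the live token_embedding table. They get their own trainable copy (temp_token_embedding) plus the OverlayLinear residual, and only persist() folds the result back into the main table. Below is a minimal self-contained sketch of that pattern, not the repository's actual class: a plain nn.Linear stands in for OverlayLinear and the sizes are illustrative.

import torch
import torch.nn as nn


class TempEmbeddingsSketch(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.token_embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.token_embedding.requires_grad_(False)  # frozen main table

        # Trainable copy, initialized from the frozen table.
        self.temp_token_embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()

        # Stand-in for OverlayLinear: a residual transform on the embeddings.
        self.overlay = nn.Linear(embedding_dim, embedding_dim)
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

    def add_embed(self, token_ids: list[int], initializer: torch.Tensor):
        # Register placeholder ids and seed both tables with the initializer.
        ids = torch.tensor(token_ids, dtype=torch.long)
        self.temp_token_ids = torch.cat([self.temp_token_ids, ids])
        self.temp_token_embedding.weight.data[ids] = initializer
        self.token_embedding.weight.data[ids] = initializer

    def get_embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        embeds = self.token_embedding(input_ids)
        # For placeholder tokens, substitute trained temp rows plus overlay.
        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
        temp_embeds = self.temp_token_embedding(input_ids[mask])
        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
        return embeds

    @torch.no_grad()
    def persist(self):
        # Fold the trained rows (plus overlay residual) back into the main table.
        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
        self.temp_token_ids = torch.tensor([], dtype=torch.long)

The main table stays untouched for ordinary vocabulary during training, so a run's learned rows only land in token_embedding once persist() is called.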
diff --git a/train_ti.py b/train_ti.py
index 0ce0056..26ac384 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -307,26 +308,6 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=4/5
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
         "--optimizer",
         type=str,
         default="dadan",
@@ -715,10 +696,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -780,7 +757,10 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.overlay.parameters(),
+        itertools.chain(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
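Because there are now two trainable parameter sets, train_ti.py feeds the optimizer a chained iterator over both; itertools.chain joins the two parameters() iterators into the single iterable the optimizer constructor expects. A short sketch of the same construction, with torch.optim.AdamW standing in for whatever create_optimizer builds (the script's default is "dadan") and illustrative sizes:

import itertools

import torch
import torch.nn as nn

temp_token_embedding = nn.Embedding(49408, 768)  # sizes are illustrative
overlay = nn.Linear(768, 768)

# One optimizer over both parameter sets; chain() yields them in sequence.
optimizer = torch.optim.AdamW(
    itertools.chain(temp_token_embedding.parameters(), overlay.parameters()),
    lr=1e-4,
)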
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 19b8d25..33f5fb9 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,6 +1,6 @@
 from typing import Optional
 from functools import partial
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from pathlib import Path
 
 import torch
@@ -13,7 +13,6 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -32,10 +31,6 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
-    use_ema: bool = False,
-    ema_inv_gamma: float = 1.0,
-    ema_power: int = 1,
-    ema_max_decay: float = 0.9999,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
     sample_num_steps: int = 20,
@@ -68,25 +63,6 @@ def textual_inversion_strategy_callbacks(
         image_size=sample_image_size,
     )
 
-    if use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.overlay.parameters(),
-            inv_gamma=ema_inv_gamma,
-            power=ema_power,
-            max_value=ema_max_decay,
-        )
-        ema_embeddings.to(accelerator.device)
-    else:
-        ema_embeddings = None
-
-    def ema_context():
-        if ema_embeddings is not None:
-            return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.overlay.parameters()
-            )
-        else:
-            return nullcontext()
-
     def on_accum_model():
         return text_encoder.text_model.embeddings.overlay
 
@@ -98,50 +74,36 @@ def textual_inversion_strategy_callbacks(
     @contextmanager
     def on_eval():
         tokenizer.eval()
-
-        with ema_context():
-            yield
-
-    @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
-
-    def on_log():
-        if ema_embeddings is not None:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
+        yield
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
-                )
+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
 
     @torch.no_grad()
     def on_sample(step):
-        with ema_context():
-            unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
-            orig_unet_dtype = unet_.dtype
-            orig_text_encoder_dtype = text_encoder_.dtype
+        orig_unet_dtype = unet_.dtype
+        orig_text_encoder_dtype = text_encoder_.dtype
 
-            unet_.to(dtype=weight_dtype)
-            text_encoder_.to(dtype=weight_dtype)
+        unet_.to(dtype=weight_dtype)
+        text_encoder_.to(dtype=weight_dtype)
 
-            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
-            unet_.to(dtype=orig_unet_dtype)
-            text_encoder_.to(dtype=orig_text_encoder_dtype)
+        unet_.to(dtype=orig_unet_dtype)
+        text_encoder_.to(dtype=orig_text_encoder_dtype)
 
-            del unet_
-            del text_encoder_
+        del unet_
+        del text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -150,8 +112,6 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
-        on_after_optimize=on_after_optimize,
-        on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
    )
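With EMA removed, the strategy callbacks reduce to their core duties; on_eval in particular becomes a bare @contextmanager. A sketch of the resulting shape, with a stub tokenizer standing in for MultiCLIPTokenizer (its train()/eval() mode switch is assumed from the surrounding code, not taken from the repository):

from contextlib import contextmanager


class StubTokenizer:
    # Stand-in for MultiCLIPTokenizer's assumed mode-switch interface.
    def train(self):
        pass

    def eval(self):
        pass


tokenizer = StubTokenizer()


@contextmanager
def on_eval():
    # Previously this wrapped the yield in `with ema_context():` to swap in
    # EMA-averaged overlay weights; now it only flips the tokenizer mode.
    tokenizer.eval()
    yield


# Usage: the body of the with-block runs with the tokenizer in eval mode.
with on_eval():
    pass  # validation / sample generation would go here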
