 models/clip/embeddings.py |   6
 train_dreambooth.py       |   2
 train_ti.py               |  59
 training/util.py          | 100
 4 files changed, 157 insertions(+), 10 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index fb639f1..384c795 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -88,7 +88,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
     def save_embed(self, input_ids: list[int], filename: Path):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
-    def make_permanent(self):
+    def persist(self):
         self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
@@ -96,9 +96,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
         embeds = self.token_embedding(input_ids)
+
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
         embeds[mask] = self.temp_token_embedding(input_ids)[mask]
 
         return embeds
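Note on the embeddings change: make_permanent() is renamed to persist(), which folds the temporary (trainable) token rows back into the permanent token_embedding table and clears temp_token_ids, and get_embed() now does the permanent lookup first and only then overwrites the masked rows; the result is unchanged, the reorder just keeps the lookup and the mask next to the lines that use them. A minimal, self-contained sketch of the masked lookup with toy tables (illustrative only, not the ManagedCLIPTextEmbeddings class itself):

import torch

token_embedding = torch.nn.Embedding(10, 4)       # permanent table
temp_token_embedding = torch.nn.Embedding(10, 4)  # trainable rows for newly added tokens
temp_token_ids = torch.tensor([7, 8], dtype=torch.long)

input_ids = torch.tensor([1, 7, 3, 8])
embeds = token_embedding(input_ids)                    # look everything up in the permanent table
mask = torch.isin(input_ids, temp_token_ids)           # positions that hold temporary tokens
embeds[mask] = temp_token_embedding(input_ids)[mask]   # overwrite only those rows

# persisting then amounts to copying the temporary rows into the permanent table:
token_embedding.weight.data[temp_token_ids] = temp_token_embedding.weight.data[temp_token_ids]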
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4d1e0a3..c355ea1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -638,7 +638,7 @@ def main():
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
 
-        embeddings.make_permanent()
+        embeddings.persist()
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
     else:
         print(f"Training added text embeddings")
diff --git a/train_ti.py b/train_ti.py
index 98385dd..dc36e42 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -2,6 +2,7 @@ import argparse
 import math
 import datetime
 import logging
+import copy
 from pathlib import Path
 from functools import partial
 
@@ -24,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.util import AverageMeter, CheckpointerBase, save_args
+from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
 from models.clip.tokenizer import MultiCLIPTokenizer
@@ -267,6 +268,27 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=6/7
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -449,6 +471,7 @@ class Checkpointer(CheckpointerBase):
         unet,
         tokenizer,
         text_encoder,
+        ema_embeddings,
         scheduler,
         placeholder_token,
         new_ids,
@@ -473,6 +496,7 @@ class Checkpointer(CheckpointerBase):
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
         self.scheduler = scheduler
         self.placeholder_token = placeholder_token
         self.new_ids = new_ids
@@ -486,17 +510,33 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         for (token, ids) in zip(self.placeholder_token, self.new_ids):
             text_encoder.text_model.embeddings.save_embed(
                 ids,
                 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
             )
 
+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         orig_dtype = text_encoder.dtype
         text_encoder.to(dtype=self.weight_dtype)
 
@@ -513,6 +553,9 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder.to(dtype=orig_dtype)
 
+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder
         del pipeline
 
@@ -567,6 +610,7 @@ def main():
         text_encoder.gradient_checkpointing_enable()
 
     embeddings = patch_managed_embeddings(text_encoder)
+    ema_embeddings = None
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -592,6 +636,14 @@ def main():
 
     print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
 
+    if args.use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+        )
+
     vae.requires_grad_(False)
     unet.requires_grad_(False)
 
@@ -788,6 +840,7 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
+    ema_embeddings.to(accelerator.device)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -883,6 +936,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
         scheduler=checkpoint_scheduler,
         placeholder_token=args.placeholder_token,
         new_ids=new_ids,
@@ -935,6 +989,9 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
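Note on the train_ti.py change: --use_ema (on by default, so ema_embeddings is normally constructed before the unconditional ema_embeddings.to(accelerator.device) call) keeps an EMA shadow of temp_token_embedding that is updated after every real optimizer step, and the Checkpointer temporarily swaps those EMA weights in before writing embedding checkpoints or sample images, restoring the live weights afterwards. A hypothetical helper, not part of the diff, that expresses the same inline swap/restore pattern as a context manager:

import copy
from contextlib import contextmanager

@contextmanager
def ema_embedding_weights(text_encoder, ema_embeddings):
    """Temporarily swap temp_token_embedding for its EMA copy, then restore the original."""
    embeddings = text_encoder.text_model.embeddings
    if ema_embeddings is None:
        yield
        return
    orig_weights = embeddings.temp_token_embedding
    ema_weights = copy.deepcopy(orig_weights)
    ema_embeddings.copy_to(ema_weights.parameters())  # EMAModel.copy_to from training/util.py
    embeddings.temp_token_embedding = ema_weights
    try:
        yield
    finally:
        embeddings.temp_token_embedding = orig_weights

Wrapping the save_embed and pipeline calls in "with ema_embedding_weights(text_encoder, self.ema_embeddings):" would replace the paired if-blocks above; the diff keeps the explicit inline form.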
diff --git a/training/util.py b/training/util.py
index 43a55e1..93b6248 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+import copy
 from typing import Iterable
 
 import torch
@@ -116,18 +117,58 @@ class CheckpointerBase:
         del generator
 
 
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """
 
-    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        update_after_step=0,
+        inv_gamma=1.0,
+        power=2 / 3,
+        min_value=0.0,
+        max_value=0.9999,
+    ):
+        """
+        @crowsonkb's notes on EMA Warmup:
+        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+        at 215.4k steps).
+        Args:
+            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+            power (float): Exponential factor of EMA warmup. Default: 2/3.
+            min_value (float): The minimum EMA decay rate. Default: 0.
+        """
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]
 
-        self.decay = decay
+        self.collected_params = None
+
+        self.update_after_step = update_after_step
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+        self.max_value = max_value
+
+        self.decay = 0.0
         self.optimization_step = 0
 
+    def get_decay(self, optimization_step):
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        step = max(0, optimization_step - self.update_after_step - 1)
+        value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+        if step <= 0:
+            return 0.0
+
+        return max(self.min_value, min(value, self.max_value))
+
     @torch.no_grad()
     def step(self, parameters):
         parameters = list(parameters)
@@ -135,12 +176,12 @@ class EMAModel:
         self.optimization_step += 1
 
         # Compute the decay factor for the exponential moving average.
-        value = (1 + self.optimization_step) / (10 + self.optimization_step)
-        one_minus_decay = 1 - min(self.decay, value)
+        self.decay = self.get_decay(self.optimization_step)
 
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                s_param.sub_(one_minus_decay * (s_param - param))
+                s_param.mul_(self.decay)
+                s_param.add_(param.data, alpha=1 - self.decay)
             else:
                 s_param.copy_(param)
 
@@ -169,3 +210,52 @@ class EMAModel:
             p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
             for p in self.shadow_params
         ]
+
+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.optimization_step = state_dict["optimization_step"]
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
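Note on the training/util.py change: EMAModel drops the hard-coded (1 + n) / (10 + n) warmup capped at a fixed decay and instead computes decay = 1 - (1 + step / inv_gamma) ** -power each step, clamped to [min_value, max_value], and it gains state_dict()/load_state_dict() so the EMA state can be checkpointed. A standalone copy of get_decay (same formula as above, shown with the defaults train_ti.py passes: inv_gamma=1.0, power=6/7, max_value=0.9999) to see how the schedule ramps:

def get_decay(optimization_step, update_after_step=0, inv_gamma=1.0, power=6 / 7,
              min_value=0.0, max_value=0.9999):
    step = max(0, optimization_step - update_after_step - 1)
    value = 1 - (1 + step / inv_gamma) ** -power
    if step <= 0:
        return 0.0
    return max(min_value, min(value, max_value))

for s in (1, 10, 100, 1_000, 10_000, 100_000):
    print(s, round(get_decay(s), 6))
# The decay starts at 0, reaches roughly 0.86 after 10 steps and roughly 0.997
# after 1,000 steps, and is clamped at max_value (0.9999) from about 46K steps on.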
