diff options
-rw-r--r-- | models/clip/embeddings.py | 6 | ||||
-rw-r--r-- | train_dreambooth.py | 2 | ||||
-rw-r--r-- | train_ti.py | 59 | ||||
-rw-r--r-- | training/util.py | 100 |
4 files changed, 157 insertions, 10 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index fb639f1..384c795 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -88,7 +88,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
88 | def save_embed(self, input_ids: list[int], filename: Path): | 88 | def save_embed(self, input_ids: list[int], filename: Path): |
89 | save_file({"embed": self.get_embed(input_ids)}, filename) | 89 | save_file({"embed": self.get_embed(input_ids)}, filename) |
90 | 90 | ||
91 | def make_permanent(self): | 91 | def persist(self): |
92 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] | 92 | self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] |
93 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 93 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
94 | 94 | ||
@@ -96,9 +96,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
96 | if isinstance(input_ids, list): | 96 | if isinstance(input_ids, list): |
97 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 97 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
98 | 98 | ||
99 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | ||
100 | |||
101 | embeds = self.token_embedding(input_ids) | 99 | embeds = self.token_embedding(input_ids) |
100 | |||
101 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | ||
102 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] | 102 | embeds[mask] = self.temp_token_embedding(input_ids)[mask] |
103 | 103 | ||
104 | return embeds | 104 | return embeds |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 4d1e0a3..c355ea1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -638,7 +638,7 @@ def main(): | |||
638 | if args.train_text_encoder: | 638 | if args.train_text_encoder: |
639 | print(f"Training entire text encoder.") | 639 | print(f"Training entire text encoder.") |
640 | 640 | ||
641 | embeddings.make_permanent() | 641 | embeddings.persist() |
642 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) | 642 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) |
643 | else: | 643 | else: |
644 | print(f"Training added text embeddings") | 644 | print(f"Training added text embeddings") |
diff --git a/train_ti.py b/train_ti.py index 98385dd..dc36e42 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -2,6 +2,7 @@ import argparse | |||
2 | import math | 2 | import math |
3 | import datetime | 3 | import datetime |
4 | import logging | 4 | import logging |
5 | import copy | ||
5 | from pathlib import Path | 6 | from pathlib import Path |
6 | from functools import partial | 7 | from functools import partial |
7 | 8 | ||
@@ -24,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem | |||
24 | from training.common import run_model | 25 | from training.common import run_model |
25 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
26 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
27 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args |
28 | from models.clip.embeddings import patch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
29 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
30 | from models.clip.tokenizer import MultiCLIPTokenizer | 31 | from models.clip.tokenizer import MultiCLIPTokenizer |
@@ -267,6 +268,27 @@ def parse_args(): | |||
267 | help="Minimum learning rate in the lr scheduler." | 268 | help="Minimum learning rate in the lr scheduler." |
268 | ) | 269 | ) |
269 | parser.add_argument( | 270 | parser.add_argument( |
271 | "--use_ema", | ||
272 | action="store_true", | ||
273 | default=True, | ||
274 | help="Whether to use EMA model." | ||
275 | ) | ||
276 | parser.add_argument( | ||
277 | "--ema_inv_gamma", | ||
278 | type=float, | ||
279 | default=1.0 | ||
280 | ) | ||
281 | parser.add_argument( | ||
282 | "--ema_power", | ||
283 | type=float, | ||
284 | default=6/7 | ||
285 | ) | ||
286 | parser.add_argument( | ||
287 | "--ema_max_decay", | ||
288 | type=float, | ||
289 | default=0.9999 | ||
290 | ) | ||
291 | parser.add_argument( | ||
270 | "--use_8bit_adam", | 292 | "--use_8bit_adam", |
271 | action="store_true", | 293 | action="store_true", |
272 | help="Whether or not to use 8-bit Adam from bitsandbytes." | 294 | help="Whether or not to use 8-bit Adam from bitsandbytes." |
@@ -449,6 +471,7 @@ class Checkpointer(CheckpointerBase): | |||
449 | unet, | 471 | unet, |
450 | tokenizer, | 472 | tokenizer, |
451 | text_encoder, | 473 | text_encoder, |
474 | ema_embeddings, | ||
452 | scheduler, | 475 | scheduler, |
453 | placeholder_token, | 476 | placeholder_token, |
454 | new_ids, | 477 | new_ids, |
@@ -473,6 +496,7 @@ class Checkpointer(CheckpointerBase): | |||
473 | self.unet = unet | 496 | self.unet = unet |
474 | self.tokenizer = tokenizer | 497 | self.tokenizer = tokenizer |
475 | self.text_encoder = text_encoder | 498 | self.text_encoder = text_encoder |
499 | self.ema_embeddings = ema_embeddings | ||
476 | self.scheduler = scheduler | 500 | self.scheduler = scheduler |
477 | self.placeholder_token = placeholder_token | 501 | self.placeholder_token = placeholder_token |
478 | self.new_ids = new_ids | 502 | self.new_ids = new_ids |
@@ -486,17 +510,33 @@ class Checkpointer(CheckpointerBase): | |||
486 | 510 | ||
487 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 511 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
488 | 512 | ||
513 | if self.ema_embeddings is not None: | ||
514 | orig_weights = text_encoder.text_model.embeddings.temp_token_embedding | ||
515 | ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding) | ||
516 | self.ema_embeddings.copy_to(ema_weights.parameters()) | ||
517 | text_encoder.text_model.embeddings.temp_token_embedding = ema_weights | ||
518 | |||
489 | for (token, ids) in zip(self.placeholder_token, self.new_ids): | 519 | for (token, ids) in zip(self.placeholder_token, self.new_ids): |
490 | text_encoder.text_model.embeddings.save_embed( | 520 | text_encoder.text_model.embeddings.save_embed( |
491 | ids, | 521 | ids, |
492 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 522 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
493 | ) | 523 | ) |
494 | 524 | ||
525 | if self.ema_embeddings is not None: | ||
526 | text_encoder.text_model.embeddings.temp_token_embedding = orig_weights | ||
527 | |||
495 | del text_encoder | 528 | del text_encoder |
496 | 529 | ||
497 | @torch.no_grad() | 530 | @torch.no_grad() |
498 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 531 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
499 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 532 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
533 | |||
534 | if self.ema_embeddings is not None: | ||
535 | orig_weights = text_encoder.text_model.embeddings.temp_token_embedding | ||
536 | ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding) | ||
537 | self.ema_embeddings.copy_to(ema_weights.parameters()) | ||
538 | text_encoder.text_model.embeddings.temp_token_embedding = ema_weights | ||
539 | |||
500 | orig_dtype = text_encoder.dtype | 540 | orig_dtype = text_encoder.dtype |
501 | text_encoder.to(dtype=self.weight_dtype) | 541 | text_encoder.to(dtype=self.weight_dtype) |
502 | 542 | ||
@@ -513,6 +553,9 @@ class Checkpointer(CheckpointerBase): | |||
513 | 553 | ||
514 | text_encoder.to(dtype=orig_dtype) | 554 | text_encoder.to(dtype=orig_dtype) |
515 | 555 | ||
556 | if self.ema_embeddings is not None: | ||
557 | text_encoder.text_model.embeddings.temp_token_embedding = orig_weights | ||
558 | |||
516 | del text_encoder | 559 | del text_encoder |
517 | del pipeline | 560 | del pipeline |
518 | 561 | ||
@@ -567,6 +610,7 @@ def main(): | |||
567 | text_encoder.gradient_checkpointing_enable() | 610 | text_encoder.gradient_checkpointing_enable() |
568 | 611 | ||
569 | embeddings = patch_managed_embeddings(text_encoder) | 612 | embeddings = patch_managed_embeddings(text_encoder) |
613 | ema_embeddings = None | ||
570 | 614 | ||
571 | if args.embeddings_dir is not None: | 615 | if args.embeddings_dir is not None: |
572 | embeddings_dir = Path(args.embeddings_dir) | 616 | embeddings_dir = Path(args.embeddings_dir) |
@@ -592,6 +636,14 @@ def main(): | |||
592 | 636 | ||
593 | print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") | 637 | print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") |
594 | 638 | ||
639 | if args.use_ema: | ||
640 | ema_embeddings = EMAModel( | ||
641 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
642 | inv_gamma=args.ema_inv_gamma, | ||
643 | power=args.ema_power, | ||
644 | max_value=args.ema_max_decay, | ||
645 | ) | ||
646 | |||
595 | vae.requires_grad_(False) | 647 | vae.requires_grad_(False) |
596 | unet.requires_grad_(False) | 648 | unet.requires_grad_(False) |
597 | 649 | ||
@@ -788,6 +840,7 @@ def main(): | |||
788 | # Move vae and unet to device | 840 | # Move vae and unet to device |
789 | vae.to(accelerator.device, dtype=weight_dtype) | 841 | vae.to(accelerator.device, dtype=weight_dtype) |
790 | unet.to(accelerator.device, dtype=weight_dtype) | 842 | unet.to(accelerator.device, dtype=weight_dtype) |
843 | ema_embeddings.to(accelerator.device) | ||
791 | 844 | ||
792 | # Keep vae and unet in eval mode as we don't train these | 845 | # Keep vae and unet in eval mode as we don't train these |
793 | vae.eval() | 846 | vae.eval() |
@@ -883,6 +936,7 @@ def main(): | |||
883 | unet=unet, | 936 | unet=unet, |
884 | tokenizer=tokenizer, | 937 | tokenizer=tokenizer, |
885 | text_encoder=text_encoder, | 938 | text_encoder=text_encoder, |
939 | ema_embeddings=ema_embeddings, | ||
886 | scheduler=checkpoint_scheduler, | 940 | scheduler=checkpoint_scheduler, |
887 | placeholder_token=args.placeholder_token, | 941 | placeholder_token=args.placeholder_token, |
888 | new_ids=new_ids, | 942 | new_ids=new_ids, |
@@ -935,6 +989,9 @@ def main(): | |||
935 | 989 | ||
936 | # Checks if the accelerator has performed an optimization step behind the scenes | 990 | # Checks if the accelerator has performed an optimization step behind the scenes |
937 | if accelerator.sync_gradients: | 991 | if accelerator.sync_gradients: |
992 | if args.use_ema: | ||
993 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
994 | |||
938 | local_progress_bar.update(1) | 995 | local_progress_bar.update(1) |
939 | global_progress_bar.update(1) | 996 | global_progress_bar.update(1) |
940 | 997 | ||
diff --git a/training/util.py b/training/util.py index 43a55e1..93b6248 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -1,5 +1,6 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | 2 | import json |
3 | import copy | ||
3 | from typing import Iterable | 4 | from typing import Iterable |
4 | 5 | ||
5 | import torch | 6 | import torch |
@@ -116,18 +117,58 @@ class CheckpointerBase: | |||
116 | del generator | 117 | del generator |
117 | 118 | ||
118 | 119 | ||
120 | # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 | ||
119 | class EMAModel: | 121 | class EMAModel: |
120 | """ | 122 | """ |
121 | Exponential Moving Average of models weights | 123 | Exponential Moving Average of models weights |
122 | """ | 124 | """ |
123 | 125 | ||
124 | def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): | 126 | def __init__( |
127 | self, | ||
128 | parameters: Iterable[torch.nn.Parameter], | ||
129 | update_after_step=0, | ||
130 | inv_gamma=1.0, | ||
131 | power=2 / 3, | ||
132 | min_value=0.0, | ||
133 | max_value=0.9999, | ||
134 | ): | ||
135 | """ | ||
136 | @crowsonkb's notes on EMA Warmup: | ||
137 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan | ||
138 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), | ||
139 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 | ||
140 | at 215.4k steps). | ||
141 | Args: | ||
142 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. | ||
143 | power (float): Exponential factor of EMA warmup. Default: 2/3. | ||
144 | min_value (float): The minimum EMA decay rate. Default: 0. | ||
145 | """ | ||
125 | parameters = list(parameters) | 146 | parameters = list(parameters) |
126 | self.shadow_params = [p.clone().detach() for p in parameters] | 147 | self.shadow_params = [p.clone().detach() for p in parameters] |
127 | 148 | ||
128 | self.decay = decay | 149 | self.collected_params = None |
150 | |||
151 | self.update_after_step = update_after_step | ||
152 | self.inv_gamma = inv_gamma | ||
153 | self.power = power | ||
154 | self.min_value = min_value | ||
155 | self.max_value = max_value | ||
156 | |||
157 | self.decay = 0.0 | ||
129 | self.optimization_step = 0 | 158 | self.optimization_step = 0 |
130 | 159 | ||
160 | def get_decay(self, optimization_step): | ||
161 | """ | ||
162 | Compute the decay factor for the exponential moving average. | ||
163 | """ | ||
164 | step = max(0, optimization_step - self.update_after_step - 1) | ||
165 | value = 1 - (1 + step / self.inv_gamma) ** -self.power | ||
166 | |||
167 | if step <= 0: | ||
168 | return 0.0 | ||
169 | |||
170 | return max(self.min_value, min(value, self.max_value)) | ||
171 | |||
131 | @torch.no_grad() | 172 | @torch.no_grad() |
132 | def step(self, parameters): | 173 | def step(self, parameters): |
133 | parameters = list(parameters) | 174 | parameters = list(parameters) |
@@ -135,12 +176,12 @@ class EMAModel: | |||
135 | self.optimization_step += 1 | 176 | self.optimization_step += 1 |
136 | 177 | ||
137 | # Compute the decay factor for the exponential moving average. | 178 | # Compute the decay factor for the exponential moving average. |
138 | value = (1 + self.optimization_step) / (10 + self.optimization_step) | 179 | self.decay = self.get_decay(self.optimization_step) |
139 | one_minus_decay = 1 - min(self.decay, value) | ||
140 | 180 | ||
141 | for s_param, param in zip(self.shadow_params, parameters): | 181 | for s_param, param in zip(self.shadow_params, parameters): |
142 | if param.requires_grad: | 182 | if param.requires_grad: |
143 | s_param.sub_(one_minus_decay * (s_param - param)) | 183 | s_param.mul_(self.decay) |
184 | s_param.add_(param.data, alpha=1 - self.decay) | ||
144 | else: | 185 | else: |
145 | s_param.copy_(param) | 186 | s_param.copy_(param) |
146 | 187 | ||
@@ -169,3 +210,52 @@ class EMAModel: | |||
169 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) | 210 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) |
170 | for p in self.shadow_params | 211 | for p in self.shadow_params |
171 | ] | 212 | ] |
213 | |||
214 | def state_dict(self) -> dict: | ||
215 | r""" | ||
216 | Returns the state of the ExponentialMovingAverage as a dict. | ||
217 | This method is used by accelerate during checkpointing to save the ema state dict. | ||
218 | """ | ||
219 | # Following PyTorch conventions, references to tensors are returned: | ||
220 | # "returns a reference to the state and not its copy!" - | ||
221 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict | ||
222 | return { | ||
223 | "decay": self.decay, | ||
224 | "optimization_step": self.optimization_step, | ||
225 | "shadow_params": self.shadow_params, | ||
226 | "collected_params": self.collected_params, | ||
227 | } | ||
228 | |||
229 | def load_state_dict(self, state_dict: dict) -> None: | ||
230 | r""" | ||
231 | Loads the ExponentialMovingAverage state. | ||
232 | This method is used by accelerate during checkpointing to save the ema state dict. | ||
233 | Args: | ||
234 | state_dict (dict): EMA state. Should be an object returned | ||
235 | from a call to :meth:`state_dict`. | ||
236 | """ | ||
237 | # deepcopy, to be consistent with module API | ||
238 | state_dict = copy.deepcopy(state_dict) | ||
239 | |||
240 | self.decay = state_dict["decay"] | ||
241 | if self.decay < 0.0 or self.decay > 1.0: | ||
242 | raise ValueError("Decay must be between 0 and 1") | ||
243 | |||
244 | self.optimization_step = state_dict["optimization_step"] | ||
245 | if not isinstance(self.optimization_step, int): | ||
246 | raise ValueError("Invalid optimization_step") | ||
247 | |||
248 | self.shadow_params = state_dict["shadow_params"] | ||
249 | if not isinstance(self.shadow_params, list): | ||
250 | raise ValueError("shadow_params must be a list") | ||
251 | if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): | ||
252 | raise ValueError("shadow_params must all be Tensors") | ||
253 | |||
254 | self.collected_params = state_dict["collected_params"] | ||
255 | if self.collected_params is not None: | ||
256 | if not isinstance(self.collected_params, list): | ||
257 | raise ValueError("collected_params must be a list") | ||
258 | if not all(isinstance(p, torch.Tensor) for p in self.collected_params): | ||
259 | raise ValueError("collected_params must all be Tensors") | ||
260 | if len(self.collected_params) != len(self.shadow_params): | ||
261 | raise ValueError("collected_params and shadow_params must have the same length") | ||