-rw-r--r--  models/clip/embeddings.py    6
-rw-r--r--  train_dreambooth.py          2
-rw-r--r--  train_ti.py                 59
-rw-r--r--  training/util.py           100
4 files changed, 157 insertions, 10 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index fb639f1..384c795 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -88,7 +88,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
     def save_embed(self, input_ids: list[int], filename: Path):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
-    def make_permanent(self):
+    def persist(self):
         self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
@@ -96,9 +96,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         if isinstance(input_ids, list):
             input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
-        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-
         embeds = self.token_embedding(input_ids)
+
+        mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
         embeds[mask] = self.temp_token_embedding(input_ids)[mask]
 
         return embeds
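
The embeddings.py hunks only reorder get_embed() (the regular token_embedding lookup now happens before the temp-token mask is built and applied) and rename make_permanent() to persist(). A minimal usage sketch, not part of the patch; the model path and token ids below are placeholders:

from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

# Placeholder model path; the training scripts load this from their CLI args.
text_encoder = CLIPTextModel.from_pretrained("path/to/model", subfolder="text_encoder")
embeddings = patch_managed_embeddings(text_encoder)  # ManagedCLIPTextEmbeddings

# get_embed(): regular lookup first, then rows whose ids are still "temporary"
# are overwritten from the trainable temp_token_embedding.
ids = [320, 1125]                    # hypothetical token ids
vectors = embeddings.get_embed(ids)  # tensor of shape (len(ids), hidden_size)

# persist() (formerly make_permanent): fold the trained temp rows back into
# token_embedding and clear temp_token_ids, so later lookups skip the mask.
embeddings.persist()
assert embeddings.temp_token_ids.numel() == 0
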
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4d1e0a3..c355ea1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -638,7 +638,7 @@ def main():
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
 
-        embeddings.make_permanent()
+        embeddings.persist()
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
     else:
         print(f"Training added text embeddings")
diff --git a/train_ti.py b/train_ti.py
index 98385dd..dc36e42 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -2,6 +2,7 @@ import argparse
 import math
 import datetime
 import logging
+import copy
 from pathlib import Path
 from functools import partial
 
@@ -24,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.util import AverageMeter, CheckpointerBase, save_args
+from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
 from models.clip.tokenizer import MultiCLIPTokenizer
@@ -267,6 +268,27 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=6/7
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -449,6 +471,7 @@ class Checkpointer(CheckpointerBase):
         unet,
         tokenizer,
         text_encoder,
+        ema_embeddings,
         scheduler,
         placeholder_token,
         new_ids,
@@ -473,6 +496,7 @@ class Checkpointer(CheckpointerBase):
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
         self.scheduler = scheduler
         self.placeholder_token = placeholder_token
         self.new_ids = new_ids
@@ -486,17 +510,33 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         for (token, ids) in zip(self.placeholder_token, self.new_ids):
             text_encoder.text_model.embeddings.save_embed(
                 ids,
                 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
             )
 
+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         orig_dtype = text_encoder.dtype
         text_encoder.to(dtype=self.weight_dtype)
 
@@ -513,6 +553,9 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder.to(dtype=orig_dtype)
 
+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder
         del pipeline
 
@@ -567,6 +610,7 @@ def main():
         text_encoder.gradient_checkpointing_enable()
 
     embeddings = patch_managed_embeddings(text_encoder)
+    ema_embeddings = None
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -592,6 +636,14 @@ def main():
 
     print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
 
+    if args.use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+        )
+
     vae.requires_grad_(False)
     unet.requires_grad_(False)
 
@@ -788,6 +840,7 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
+    ema_embeddings.to(accelerator.device)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -883,6 +936,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
         scheduler=checkpoint_scheduler,
         placeholder_token=args.placeholder_token,
         new_ids=new_ids,
@@ -935,6 +989,9 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
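
Both the embedding-saving and the sample-saving paths above use the same swap-in/restore pattern: keep a reference to the live temp_token_embedding, deepcopy it, overwrite the copy with the EMA shadow weights via copy_to(), and put the original module back when done. A condensed sketch of that pattern; the helper name with_ema_weights is made up for illustration, and `embeddings` stands for text_encoder.text_model.embeddings:

import copy

def with_ema_weights(embeddings, ema_embeddings, fn):
    """Run fn() with the EMA shadow weights substituted for temp_token_embedding."""
    if ema_embeddings is None:
        return fn()
    orig = embeddings.temp_token_embedding
    ema = copy.deepcopy(orig)                 # same shape/device as the live module
    ema_embeddings.copy_to(ema.parameters())  # overwrite the copy with the shadow params
    embeddings.temp_token_embedding = ema
    try:
        return fn()                           # e.g. save_embed(...) or sample generation
    finally:
        embeddings.temp_token_embedding = orig  # always restore the live weights

During training, the shadow weights are refreshed right after each optimizer step, which is what the `if args.use_ema: ema_embeddings.step(...)` block under accelerator.sync_gradients does.
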
diff --git a/training/util.py b/training/util.py
index 43a55e1..93b6248 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+import copy
 from typing import Iterable
 
 import torch
@@ -116,18 +117,58 @@ class CheckpointerBase:
         del generator
 
 
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
     """
 
-    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        update_after_step=0,
+        inv_gamma=1.0,
+        power=2 / 3,
+        min_value=0.0,
+        max_value=0.9999,
+    ):
+        """
+        @crowsonkb's notes on EMA Warmup:
+            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+            at 215.4k steps).
+        Args:
+            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+            power (float): Exponential factor of EMA warmup. Default: 2/3.
+            min_value (float): The minimum EMA decay rate. Default: 0.
+        """
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]
 
-        self.decay = decay
+        self.collected_params = None
+
+        self.update_after_step = update_after_step
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+        self.max_value = max_value
+
+        self.decay = 0.0
         self.optimization_step = 0
 
+    def get_decay(self, optimization_step):
+        """
+        Compute the decay factor for the exponential moving average.
+        """
+        step = max(0, optimization_step - self.update_after_step - 1)
+        value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+        if step <= 0:
+            return 0.0
+
+        return max(self.min_value, min(value, self.max_value))
+
     @torch.no_grad()
     def step(self, parameters):
         parameters = list(parameters)
@@ -135,12 +176,12 @@ class EMAModel:
         self.optimization_step += 1
 
         # Compute the decay factor for the exponential moving average.
-        value = (1 + self.optimization_step) / (10 + self.optimization_step)
-        one_minus_decay = 1 - min(self.decay, value)
+        self.decay = self.get_decay(self.optimization_step)
 
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                s_param.sub_(one_minus_decay * (s_param - param))
+                s_param.mul_(self.decay)
+                s_param.add_(param.data, alpha=1 - self.decay)
             else:
                 s_param.copy_(param)
 
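
The step() rewrite replaces the old ramp, decay = min(0.9999, (1 + n) / (10 + n)), with @crowsonkb's warmup schedule: decay = 1 - (1 + step / inv_gamma) ** -power, clamped to [min_value, max_value]. A small standalone worked example with train_ti.py's defaults (inv_gamma=1.0, power=6/7, max_decay=0.9999); this re-implements get_decay purely for illustration:

def decay_at(optimization_step, inv_gamma=1.0, power=6 / 7,
             min_value=0.0, max_value=0.9999, update_after_step=0):
    # Mirrors EMAModel.get_decay() above.
    step = max(0, optimization_step - update_after_step - 1)
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

for s in (1, 10, 100, 1_000, 10_000, 100_000):
    print(s, round(decay_at(s), 5))

# Roughly: the decay starts at 0, passes ~0.98 near step 100 and ~0.997 near
# step 1k, crosses 0.999 around step 3k, and saturates at max_value (0.9999)
# after roughly 50k steps.
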
@@ -169,3 +210,52 @@ class EMAModel:
             p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
             for p in self.shadow_params
         ]
+
+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.optimization_step = state_dict["optimization_step"]
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
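
state_dict()/load_state_dict() make the EMA state serializable so it can travel with the rest of the training state (the docstrings mention accelerate checkpointing). A round-trip sketch; the toy embedding stands in for temp_token_embedding and the file name is hypothetical:

import torch

from training.util import EMAModel

weight = torch.nn.Embedding(4, 8)  # stand-in for temp_token_embedding
ema = EMAModel(weight.parameters(), inv_gamma=1.0, power=6 / 7, max_value=0.9999)

# ...training would call ema.step(weight.parameters()) after each optimizer step...

torch.save(ema.state_dict(), "ema_embeddings.ckpt")

restored = EMAModel(weight.parameters(), inv_gamma=1.0, power=6 / 7, max_value=0.9999)
restored.load_state_dict(torch.load("ema_embeddings.ckpt"))
assert restored.optimization_step == ema.optimization_step
assert all(torch.equal(a, b) for a, b in zip(restored.shadow_params, ema.shadow_params))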