summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--environment.yaml2
-rw-r--r--train_dreambooth.py10
-rw-r--r--train_ti.py21
-rw-r--r--training/functional.py35
-rw-r--r--training/lr.py266
-rw-r--r--training/optimization.py19
-rw-r--r--training/strategy/dreambooth.py4
-rw-r--r--training/strategy/ti.py5
-rw-r--r--training/util.py146
9 files changed, 111 insertions, 397 deletions
diff --git a/environment.yaml b/environment.yaml
index 03345c6..c992759 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -25,4 +25,4 @@ dependencies:
25 - test-tube>=0.7.5 25 - test-tube>=0.7.5
26 - transformers==4.25.1 26 - transformers==4.25.1
27 - triton==2.0.0.dev20221202 27 - triton==2.0.0.dev20221202
28 - xformers==0.0.16rc403 28 - xformers==0.0.16.dev430
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9c1e41c..a70c80e 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -16,6 +16,7 @@ from slugify import slugify
16from util import load_config, load_embeddings_from_dir 16from util import load_config, load_embeddings_from_dir
17from data.csv import VlpnDataModule, keyword_filter 17from data.csv import VlpnDataModule, keyword_filter
18from training.functional import train, get_models 18from training.functional import train, get_models
19from training.lr import plot_metrics
19from training.strategy.dreambooth import dreambooth_strategy 20from training.strategy.dreambooth import dreambooth_strategy
20from training.optimization import get_scheduler 21from training.optimization import get_scheduler
21from training.util import save_args 22from training.util import save_args
@@ -524,6 +525,10 @@ def main():
524 args.train_batch_size * accelerator.num_processes 525 args.train_batch_size * accelerator.num_processes
525 ) 526 )
526 527
528 if args.find_lr:
529 args.learning_rate = 1e-6
530 args.lr_scheduler = "exponential_growth"
531
527 if args.use_8bit_adam: 532 if args.use_8bit_adam:
528 try: 533 try:
529 import bitsandbytes as bnb 534 import bitsandbytes as bnb
@@ -602,11 +607,12 @@ def main():
602 warmup_exp=args.lr_warmup_exp, 607 warmup_exp=args.lr_warmup_exp,
603 annealing_exp=args.lr_annealing_exp, 608 annealing_exp=args.lr_annealing_exp,
604 cycles=args.lr_cycles, 609 cycles=args.lr_cycles,
610 end_lr=1e2,
605 train_epochs=args.num_train_epochs, 611 train_epochs=args.num_train_epochs,
606 warmup_epochs=args.lr_warmup_epochs, 612 warmup_epochs=args.lr_warmup_epochs,
607 ) 613 )
608 614
609 trainer( 615 metrics = trainer(
610 strategy=dreambooth_strategy, 616 strategy=dreambooth_strategy,
611 project="dreambooth", 617 project="dreambooth",
612 train_dataloader=datamodule.train_dataloader, 618 train_dataloader=datamodule.train_dataloader,
@@ -634,6 +640,8 @@ def main():
634 sample_image_size=args.sample_image_size, 640 sample_image_size=args.sample_image_size,
635 ) 641 )
636 642
643 plot_metrics(metrics, output_dir.joinpath("lr.png"))
644
637 645
638if __name__ == "__main__": 646if __name__ == "__main__":
639 main() 647 main()
diff --git a/train_ti.py b/train_ti.py
index 451b61b..c118aab 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -15,6 +15,7 @@ from slugify import slugify
15from util import load_config, load_embeddings_from_dir 15from util import load_config, load_embeddings_from_dir
16from data.csv import VlpnDataModule, keyword_filter 16from data.csv import VlpnDataModule, keyword_filter
17from training.functional import train, add_placeholder_tokens, get_models 17from training.functional import train, add_placeholder_tokens, get_models
18from training.lr import plot_metrics
18from training.strategy.ti import textual_inversion_strategy 19from training.strategy.ti import textual_inversion_strategy
19from training.optimization import get_scheduler 20from training.optimization import get_scheduler
20from training.util import save_args 21from training.util import save_args
@@ -61,6 +62,12 @@ def parse_args():
61 help="The name of the current project.", 62 help="The name of the current project.",
62 ) 63 )
63 parser.add_argument( 64 parser.add_argument(
65 "--skip_first",
66 type=int,
67 default=0,
68 help="Tokens to skip training for.",
69 )
70 parser.add_argument(
64 "--placeholder_tokens", 71 "--placeholder_tokens",
65 type=str, 72 type=str,
66 nargs='*', 73 nargs='*',
@@ -407,7 +414,7 @@ def parse_args():
407 ) 414 )
408 parser.add_argument( 415 parser.add_argument(
409 "--emb_decay", 416 "--emb_decay",
410 default=10, 417 default=1e0,
411 type=float, 418 type=float,
412 help="Embedding decay factor." 419 help="Embedding decay factor."
413 ) 420 )
@@ -543,6 +550,10 @@ def main():
543 args.train_batch_size * accelerator.num_processes 550 args.train_batch_size * accelerator.num_processes
544 ) 551 )
545 552
553 if args.find_lr:
554 args.learning_rate = 1e-5
555 args.lr_scheduler = "exponential_growth"
556
546 if args.use_8bit_adam: 557 if args.use_8bit_adam:
547 try: 558 try:
548 import bitsandbytes as bnb 559 import bitsandbytes as bnb
@@ -596,6 +607,9 @@ def main():
596 ) 607 )
597 608
598 def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): 609 def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
610 if i < args.skip_first:
611 return
612
599 if len(placeholder_tokens) == 1: 613 if len(placeholder_tokens) == 1:
600 sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") 614 sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
601 else: 615 else:
@@ -656,11 +670,12 @@ def main():
656 warmup_exp=args.lr_warmup_exp, 670 warmup_exp=args.lr_warmup_exp,
657 annealing_exp=args.lr_annealing_exp, 671 annealing_exp=args.lr_annealing_exp,
658 cycles=args.lr_cycles, 672 cycles=args.lr_cycles,
673 end_lr=1e3,
659 train_epochs=args.num_train_epochs, 674 train_epochs=args.num_train_epochs,
660 warmup_epochs=args.lr_warmup_epochs, 675 warmup_epochs=args.lr_warmup_epochs,
661 ) 676 )
662 677
663 trainer( 678 metrics = trainer(
664 project="textual_inversion", 679 project="textual_inversion",
665 train_dataloader=datamodule.train_dataloader, 680 train_dataloader=datamodule.train_dataloader,
666 val_dataloader=datamodule.val_dataloader, 681 val_dataloader=datamodule.val_dataloader,
@@ -672,6 +687,8 @@ def main():
672 placeholder_token_ids=placeholder_token_ids, 687 placeholder_token_ids=placeholder_token_ids,
673 ) 688 )
674 689
690 plot_metrics(metrics, output_dir.joinpath("lr.png"))
691
675 if args.simultaneous: 692 if args.simultaneous:
676 run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) 693 run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
677 else: 694 else:
diff --git a/training/functional.py b/training/functional.py
index fb135c4..c373ac9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -7,7 +7,6 @@ from pathlib import Path
7import itertools 7import itertools
8 8
9import torch 9import torch
10import torch.nn as nn
11import torch.nn.functional as F 10import torch.nn.functional as F
12from torch.utils.data import DataLoader 11from torch.utils.data import DataLoader
13 12
@@ -373,8 +372,12 @@ def train_loop(
373 avg_loss_val = AverageMeter() 372 avg_loss_val = AverageMeter()
374 avg_acc_val = AverageMeter() 373 avg_acc_val = AverageMeter()
375 374
376 max_acc = 0.0 375 best_acc = 0.0
377 max_acc_val = 0.0 376 best_acc_val = 0.0
377
378 lrs = []
379 losses = []
380 accs = []
378 381
379 local_progress_bar = tqdm( 382 local_progress_bar = tqdm(
380 range(num_training_steps_per_epoch + num_val_steps_per_epoch), 383 range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -457,6 +460,8 @@ def train_loop(
457 460
458 accelerator.wait_for_everyone() 461 accelerator.wait_for_everyone()
459 462
463 lrs.append(lr_scheduler.get_last_lr()[0])
464
460 on_after_epoch(lr_scheduler.get_last_lr()[0]) 465 on_after_epoch(lr_scheduler.get_last_lr()[0])
461 466
462 if val_dataloader is not None: 467 if val_dataloader is not None:
@@ -498,18 +503,24 @@ def train_loop(
498 global_progress_bar.clear() 503 global_progress_bar.clear()
499 504
500 if accelerator.is_main_process: 505 if accelerator.is_main_process:
501 if avg_acc_val.avg.item() > max_acc_val: 506 if avg_acc_val.avg.item() > best_acc_val:
502 accelerator.print( 507 accelerator.print(
503 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") 508 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
504 on_checkpoint(global_step + global_step_offset, "milestone") 509 on_checkpoint(global_step + global_step_offset, "milestone")
505 max_acc_val = avg_acc_val.avg.item() 510 best_acc_val = avg_acc_val.avg.item()
511
512 losses.append(avg_loss_val.avg.item())
513 accs.append(avg_acc_val.avg.item())
506 else: 514 else:
507 if accelerator.is_main_process: 515 if accelerator.is_main_process:
508 if avg_acc.avg.item() > max_acc: 516 if avg_acc.avg.item() > best_acc:
509 accelerator.print( 517 accelerator.print(
510 f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") 518 f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
511 on_checkpoint(global_step + global_step_offset, "milestone") 519 on_checkpoint(global_step + global_step_offset, "milestone")
512 max_acc = avg_acc.avg.item() 520 best_acc = avg_acc.avg.item()
521
522 losses.append(avg_loss.avg.item())
523 accs.append(avg_acc.avg.item())
513 524
514 # Create the pipeline using using the trained modules and save it. 525 # Create the pipeline using using the trained modules and save it.
515 if accelerator.is_main_process: 526 if accelerator.is_main_process:
@@ -523,6 +534,8 @@ def train_loop(
523 on_checkpoint(global_step + global_step_offset, "end") 534 on_checkpoint(global_step + global_step_offset, "end")
524 raise KeyboardInterrupt 535 raise KeyboardInterrupt
525 536
537 return lrs, losses, accs
538
526 539
527def train( 540def train(
528 accelerator: Accelerator, 541 accelerator: Accelerator,
@@ -582,7 +595,7 @@ def train(
582 if accelerator.is_main_process: 595 if accelerator.is_main_process:
583 accelerator.init_trackers(project) 596 accelerator.init_trackers(project)
584 597
585 train_loop( 598 metrics = train_loop(
586 accelerator=accelerator, 599 accelerator=accelerator,
587 optimizer=optimizer, 600 optimizer=optimizer,
588 lr_scheduler=lr_scheduler, 601 lr_scheduler=lr_scheduler,
@@ -598,3 +611,5 @@ def train(
598 611
599 accelerator.end_training() 612 accelerator.end_training()
600 accelerator.free_memory() 613 accelerator.free_memory()
614
615 return metrics
diff --git a/training/lr.py b/training/lr.py
index 9690738..f5b362f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,238 +1,36 @@
1import math 1from pathlib import Path
2from contextlib import _GeneratorContextManager, nullcontext
3from typing import Callable, Any, Tuple, Union
4from functools import partial
5 2
6import matplotlib.pyplot as plt 3import matplotlib.pyplot as plt
7import numpy as np
8import torch
9from torch.optim.lr_scheduler import LambdaLR
10from tqdm.auto import tqdm
11 4
12from training.functional import TrainingCallbacks
13from training.util import AverageMeter
14 5
15 6def plot_metrics(
16def noop(*args, **kwards): 7 metrics: tuple[list[float], list[float], list[float]],
17 pass 8 output_file: Path,
18 9 skip_start: int = 10,
19 10 skip_end: int = 5,
20def noop_ctx(*args, **kwards): 11):
21 return nullcontext() 12 lrs, losses, accs = metrics
22 13
23 14 if skip_end == 0:
24class LRFinder(): 15 lrs = lrs[skip_start:]
25 def __init__( 16 losses = losses[skip_start:]
26 self, 17 accs = accs[skip_start:]
27 accelerator, 18 else:
28 optimizer, 19 lrs = lrs[skip_start:-skip_end]
29 train_dataloader, 20 losses = losses[skip_start:-skip_end]
30 val_dataloader, 21 accs = accs[skip_start:-skip_end]
31 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], 22
32 callbacks: TrainingCallbacks = TrainingCallbacks() 23 fig, ax_loss = plt.subplots()
33 ): 24 ax_acc = ax_loss.twinx()
34 self.accelerator = accelerator 25
35 self.model = callbacks.on_model() 26 ax_loss.plot(lrs, losses, color='red')
36 self.optimizer = optimizer 27 ax_loss.set_xscale("log")
37 self.train_dataloader = train_dataloader 28 ax_loss.set_xlabel(f"Learning rate")
38 self.val_dataloader = val_dataloader 29 ax_loss.set_ylabel("Loss")
39 self.loss_fn = loss_fn 30
40 self.callbacks = callbacks 31 ax_acc.plot(lrs, accs, color='blue')
41 32 ax_acc.set_xscale("log")
42 # self.model_state = copy.deepcopy(model.state_dict()) 33 ax_acc.set_ylabel("Accuracy")
43 # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) 34
44 35 plt.savefig(output_file, dpi=300)
45 def run( 36 plt.close()
46 self,
47 end_lr,
48 skip_start: int = 10,
49 skip_end: int = 5,
50 num_epochs: int = 100,
51 num_train_batches: int = math.inf,
52 num_val_batches: int = math.inf,
53 smooth_f: float = 0.05,
54 ):
55 best_loss = None
56 best_acc = None
57
58 lrs = []
59 losses = []
60 accs = []
61
62 lr_scheduler = get_exponential_schedule(
63 self.optimizer,
64 end_lr,
65 num_epochs * min(num_train_batches, len(self.train_dataloader))
66 )
67
68 steps = min(num_train_batches, len(self.train_dataloader))
69 steps += min(num_val_batches, len(self.val_dataloader))
70 steps *= num_epochs
71
72 progress_bar = tqdm(
73 range(steps),
74 disable=not self.accelerator.is_local_main_process,
75 dynamic_ncols=True
76 )
77 progress_bar.set_description("Epoch X / Y")
78
79 self.callbacks.on_prepare()
80
81 on_train = self.callbacks.on_train
82 on_before_optimize = self.callbacks.on_before_optimize
83 on_after_optimize = self.callbacks.on_after_optimize
84 on_eval = self.callbacks.on_eval
85
86 for epoch in range(num_epochs):
87 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
88
89 avg_loss = AverageMeter()
90 avg_acc = AverageMeter()
91
92 self.model.train()
93
94 with on_train(epoch):
95 for step, batch in enumerate(self.train_dataloader):
96 if step >= num_train_batches:
97 break
98
99 with self.accelerator.accumulate(self.model):
100 loss, acc, bsz = self.loss_fn(step, batch)
101
102 self.accelerator.backward(loss)
103
104 on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
105
106 self.optimizer.step()
107 lr_scheduler.step()
108 self.optimizer.zero_grad(set_to_none=True)
109
110 if self.accelerator.sync_gradients:
111 on_after_optimize(lr_scheduler.get_last_lr()[0])
112
113 progress_bar.update(1)
114
115 self.model.eval()
116
117 with torch.inference_mode():
118 with on_eval():
119 for step, batch in enumerate(self.val_dataloader):
120 if step >= num_val_batches:
121 break
122
123 loss, acc, bsz = self.loss_fn(step, batch, True)
124 avg_loss.update(loss.detach_(), bsz)
125 avg_acc.update(acc.detach_(), bsz)
126
127 progress_bar.update(1)
128
129 loss = avg_loss.avg.item()
130 acc = avg_acc.avg.item()
131
132 if epoch == 0:
133 best_loss = loss
134 best_acc = acc
135 else:
136 if smooth_f > 0:
137 loss = smooth_f * loss + (1 - smooth_f) * losses[-1]
138 acc = smooth_f * acc + (1 - smooth_f) * accs[-1]
139 if loss < best_loss:
140 best_loss = loss
141 if acc > best_acc:
142 best_acc = acc
143
144 lr = lr_scheduler.get_last_lr()[0]
145
146 lrs.append(lr)
147 losses.append(loss)
148 accs.append(acc)
149
150 self.accelerator.log({
151 "loss": loss,
152 "acc": acc,
153 "lr": lr,
154 }, step=epoch)
155
156 progress_bar.set_postfix({
157 "loss": loss,
158 "loss/best": best_loss,
159 "acc": acc,
160 "acc/best": best_acc,
161 "lr": lr,
162 })
163
164 # self.model.load_state_dict(self.model_state)
165 # self.optimizer.load_state_dict(self.optimizer_state)
166
167 if skip_end == 0:
168 lrs = lrs[skip_start:]
169 losses = losses[skip_start:]
170 accs = accs[skip_start:]
171 else:
172 lrs = lrs[skip_start:-skip_end]
173 losses = losses[skip_start:-skip_end]
174 accs = accs[skip_start:-skip_end]
175
176 fig, ax_loss = plt.subplots()
177 ax_acc = ax_loss.twinx()
178
179 ax_loss.plot(lrs, losses, color='red')
180 ax_loss.set_xscale("log")
181 ax_loss.set_xlabel(f"Learning rate")
182 ax_loss.set_ylabel("Loss")
183
184 ax_acc.plot(lrs, accs, color='blue')
185 ax_acc.set_xscale("log")
186 ax_acc.set_ylabel("Accuracy")
187
188 print("LR suggestion: steepest gradient")
189 min_grad_idx = None
190
191 try:
192 min_grad_idx = np.gradient(np.array(losses)).argmin()
193 except ValueError:
194 print(
195 "Failed to compute the gradients, there might not be enough points."
196 )
197
198 try:
199 max_val_idx = np.array(accs).argmax()
200 except ValueError:
201 print(
202 "Failed to compute the gradients, there might not be enough points."
203 )
204
205 if min_grad_idx is not None:
206 print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx]))
207 ax_loss.scatter(
208 lrs[min_grad_idx],
209 losses[min_grad_idx],
210 s=75,
211 marker="o",
212 color="red",
213 zorder=3,
214 label="steepest gradient",
215 )
216 ax_loss.legend()
217
218 if max_val_idx is not None:
219 print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx]))
220 ax_acc.scatter(
221 lrs[max_val_idx],
222 accs[max_val_idx],
223 s=75,
224 marker="o",
225 color="blue",
226 zorder=3,
227 label="maximum",
228 )
229 ax_acc.legend()
230
231
232def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1):
233 def lr_lambda(base_lr: float, current_epoch: int):
234 return (end_lr / base_lr) ** (current_epoch / num_epochs)
235
236 lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
237
238 return LambdaLR(optimizer, lr_lambdas, last_epoch)
diff --git a/training/optimization.py b/training/optimization.py
index 6dee4bc..6c9a35d 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -87,6 +87,15 @@ def get_one_cycle_schedule(
87 return LambdaLR(optimizer, lr_lambda, last_epoch) 87 return LambdaLR(optimizer, lr_lambda, last_epoch)
88 88
89 89
90def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
91 def lr_lambda(base_lr: float, current_step: int):
92 return (end_lr / base_lr) ** (current_step / num_training_steps)
93
94 lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups]
95
96 return LambdaLR(optimizer, lr_lambdas, last_epoch)
97
98
90def get_scheduler( 99def get_scheduler(
91 id: str, 100 id: str,
92 optimizer: torch.optim.Optimizer, 101 optimizer: torch.optim.Optimizer,
@@ -97,6 +106,7 @@ def get_scheduler(
97 annealing_func: Literal["cos", "half_cos", "linear"] = "cos", 106 annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
98 warmup_exp: int = 1, 107 warmup_exp: int = 1,
99 annealing_exp: int = 1, 108 annealing_exp: int = 1,
109 end_lr: float = 1e3,
100 cycles: int = 1, 110 cycles: int = 1,
101 train_epochs: int = 100, 111 train_epochs: int = 100,
102 warmup_epochs: int = 10, 112 warmup_epochs: int = 10,
@@ -117,6 +127,15 @@ def get_scheduler(
117 annealing_exp=annealing_exp, 127 annealing_exp=annealing_exp,
118 min_lr=min_lr, 128 min_lr=min_lr,
119 ) 129 )
130 elif id == "exponential_growth":
131 if cycles is None:
132 cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
133
134 lr_scheduler = get_exponential_growing_schedule(
135 optimizer=optimizer,
136 end_lr=end_lr,
137 num_training_steps=num_training_steps,
138 )
120 elif id == "cosine_with_restarts": 139 elif id == "cosine_with_restarts":
121 if cycles is None: 140 if cycles is None:
122 cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) 141 cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 1277939..e88bf90 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -193,9 +193,7 @@ def dreambooth_prepare(
193 unet: UNet2DConditionModel, 193 unet: UNet2DConditionModel,
194 *args 194 *args
195): 195):
196 prep = [text_encoder, unet] + list(args) 196 return accelerator.prepare(text_encoder, unet, *args)
197 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
198 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
199 197
200 198
201dreambooth_strategy = TrainingStrategy( 199dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a76f98..14bdafd 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -176,10 +176,9 @@ def textual_inversion_prepare(
176 elif accelerator.state.mixed_precision == "bf16": 176 elif accelerator.state.mixed_precision == "bf16":
177 weight_dtype = torch.bfloat16 177 weight_dtype = torch.bfloat16
178 178
179 prep = [text_encoder] + list(args) 179 prepped = accelerator.prepare(text_encoder, *args)
180 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
181 unet.to(accelerator.device, dtype=weight_dtype) 180 unet.to(accelerator.device, dtype=weight_dtype)
182 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 181 return (prepped[0], unet) + prepped[1:]
183 182
184 183
185textual_inversion_strategy = TrainingStrategy( 184textual_inversion_strategy = TrainingStrategy(
diff --git a/training/util.py b/training/util.py
index 237626f..c8524de 100644
--- a/training/util.py
+++ b/training/util.py
@@ -6,6 +6,8 @@ from contextlib import contextmanager
6 6
7import torch 7import torch
8 8
9from diffusers.training_utils import EMAModel as EMAModel_
10
9 11
10def save_args(basepath: Path, args, extra={}): 12def save_args(basepath: Path, args, extra={}):
11 info = {"args": vars(args)} 13 info = {"args": vars(args)}
@@ -30,149 +32,7 @@ class AverageMeter:
30 self.avg = self.sum / self.count 32 self.avg = self.sum / self.count
31 33
32 34
33# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 35class EMAModel(EMAModel_):
34class EMAModel:
35 """
36 Exponential Moving Average of models weights
37 """
38
39 def __init__(
40 self,
41 parameters: Iterable[torch.nn.Parameter],
42 update_after_step: int = 0,
43 inv_gamma: float = 1.0,
44 power: float = 2 / 3,
45 min_value: float = 0.0,
46 max_value: float = 0.9999,
47 ):
48 """
49 @crowsonkb's notes on EMA Warmup:
50 If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
51 to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
52 gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
53 at 215.4k steps).
54 Args:
55 inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
56 power (float): Exponential factor of EMA warmup. Default: 2/3.
57 min_value (float): The minimum EMA decay rate. Default: 0.
58 """
59 parameters = list(parameters)
60 self.shadow_params = [p.clone().detach() for p in parameters]
61
62 self.collected_params = None
63
64 self.update_after_step = update_after_step
65 self.inv_gamma = inv_gamma
66 self.power = power
67 self.min_value = min_value
68 self.max_value = max_value
69
70 self.decay = 0.0
71 self.optimization_step = 0
72
73 def get_decay(self, optimization_step: int):
74 """
75 Compute the decay factor for the exponential moving average.
76 """
77 step = max(0, optimization_step - self.update_after_step - 1)
78 value = 1 - (1 + step / self.inv_gamma) ** -self.power
79
80 if step <= 0:
81 return 0.0
82
83 return max(self.min_value, min(value, self.max_value))
84
85 @torch.no_grad()
86 def step(self, parameters):
87 parameters = list(parameters)
88
89 self.optimization_step += 1
90
91 # Compute the decay factor for the exponential moving average.
92 self.decay = self.get_decay(self.optimization_step)
93
94 for s_param, param in zip(self.shadow_params, parameters):
95 if param.requires_grad:
96 s_param.mul_(self.decay)
97 s_param.add_(param.data, alpha=1 - self.decay)
98 else:
99 s_param.copy_(param)
100
101 torch.cuda.empty_cache()
102
103 def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
104 """
105 Copy current averaged parameters into given collection of parameters.
106 Args:
107 parameters: Iterable of `torch.nn.Parameter`; the parameters to be
108 updated with the stored moving averages. If `None`, the
109 parameters with which this `ExponentialMovingAverage` was
110 initialized will be used.
111 """
112 parameters = list(parameters)
113 for s_param, param in zip(self.shadow_params, parameters):
114 param.data.copy_(s_param.data)
115
116 def to(self, device=None, dtype=None) -> None:
117 r"""Move internal buffers of the ExponentialMovingAverage to `device`.
118 Args:
119 device: like `device` argument to `torch.Tensor.to`
120 """
121 # .to() on the tensors handles None correctly
122 self.shadow_params = [
123 p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
124 for p in self.shadow_params
125 ]
126
127 def state_dict(self) -> dict:
128 r"""
129 Returns the state of the ExponentialMovingAverage as a dict.
130 This method is used by accelerate during checkpointing to save the ema state dict.
131 """
132 # Following PyTorch conventions, references to tensors are returned:
133 # "returns a reference to the state and not its copy!" -
134 # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
135 return {
136 "decay": self.decay,
137 "optimization_step": self.optimization_step,
138 "shadow_params": self.shadow_params,
139 "collected_params": self.collected_params,
140 }
141
142 def load_state_dict(self, state_dict: dict) -> None:
143 r"""
144 Loads the ExponentialMovingAverage state.
145 This method is used by accelerate during checkpointing to save the ema state dict.
146 Args:
147 state_dict (dict): EMA state. Should be an object returned
148 from a call to :meth:`state_dict`.
149 """
150 # deepcopy, to be consistent with module API
151 state_dict = copy.deepcopy(state_dict)
152
153 self.decay = state_dict["decay"]
154 if self.decay < 0.0 or self.decay > 1.0:
155 raise ValueError("Decay must be between 0 and 1")
156
157 self.optimization_step = state_dict["optimization_step"]
158 if not isinstance(self.optimization_step, int):
159 raise ValueError("Invalid optimization_step")
160
161 self.shadow_params = state_dict["shadow_params"]
162 if not isinstance(self.shadow_params, list):
163 raise ValueError("shadow_params must be a list")
164 if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
165 raise ValueError("shadow_params must all be Tensors")
166
167 self.collected_params = state_dict["collected_params"]
168 if self.collected_params is not None:
169 if not isinstance(self.collected_params, list):
170 raise ValueError("collected_params must be a list")
171 if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
172 raise ValueError("collected_params must all be Tensors")
173 if len(self.collected_params) != len(self.shadow_params):
174 raise ValueError("collected_params and shadow_params must have the same length")
175
176 @contextmanager 36 @contextmanager
177 def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): 37 def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
178 parameters = list(parameters) 38 parameters = list(parameters)