diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 35 | ||||
-rw-r--r-- | training/lr.py | 266 | ||||
-rw-r--r-- | training/optimization.py | 19 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 4 | ||||
-rw-r--r-- | training/strategy/ti.py | 5 | ||||
-rw-r--r-- | training/util.py | 146 |
6 files changed, 82 insertions, 393 deletions
diff --git a/training/functional.py b/training/functional.py index fb135c4..c373ac9 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -7,7 +7,6 @@ from pathlib import Path | |||
7 | import itertools | 7 | import itertools |
8 | 8 | ||
9 | import torch | 9 | import torch |
10 | import torch.nn as nn | ||
11 | import torch.nn.functional as F | 10 | import torch.nn.functional as F |
12 | from torch.utils.data import DataLoader | 11 | from torch.utils.data import DataLoader |
13 | 12 | ||
@@ -373,8 +372,12 @@ def train_loop( | |||
373 | avg_loss_val = AverageMeter() | 372 | avg_loss_val = AverageMeter() |
374 | avg_acc_val = AverageMeter() | 373 | avg_acc_val = AverageMeter() |
375 | 374 | ||
376 | max_acc = 0.0 | 375 | best_acc = 0.0 |
377 | max_acc_val = 0.0 | 376 | best_acc_val = 0.0 |
377 | |||
378 | lrs = [] | ||
379 | losses = [] | ||
380 | accs = [] | ||
378 | 381 | ||
379 | local_progress_bar = tqdm( | 382 | local_progress_bar = tqdm( |
380 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 383 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
@@ -457,6 +460,8 @@ def train_loop( | |||
457 | 460 | ||
458 | accelerator.wait_for_everyone() | 461 | accelerator.wait_for_everyone() |
459 | 462 | ||
463 | lrs.append(lr_scheduler.get_last_lr()[0]) | ||
464 | |||
460 | on_after_epoch(lr_scheduler.get_last_lr()[0]) | 465 | on_after_epoch(lr_scheduler.get_last_lr()[0]) |
461 | 466 | ||
462 | if val_dataloader is not None: | 467 | if val_dataloader is not None: |
@@ -498,18 +503,24 @@ def train_loop( | |||
498 | global_progress_bar.clear() | 503 | global_progress_bar.clear() |
499 | 504 | ||
500 | if accelerator.is_main_process: | 505 | if accelerator.is_main_process: |
501 | if avg_acc_val.avg.item() > max_acc_val: | 506 | if avg_acc_val.avg.item() > best_acc_val: |
502 | accelerator.print( | 507 | accelerator.print( |
503 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 508 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
504 | on_checkpoint(global_step + global_step_offset, "milestone") | 509 | on_checkpoint(global_step + global_step_offset, "milestone") |
505 | max_acc_val = avg_acc_val.avg.item() | 510 | best_acc_val = avg_acc_val.avg.item() |
511 | |||
512 | losses.append(avg_loss_val.avg.item()) | ||
513 | accs.append(avg_acc_val.avg.item()) | ||
506 | else: | 514 | else: |
507 | if accelerator.is_main_process: | 515 | if accelerator.is_main_process: |
508 | if avg_acc.avg.item() > max_acc: | 516 | if avg_acc.avg.item() > best_acc: |
509 | accelerator.print( | 517 | accelerator.print( |
510 | f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") | 518 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") |
511 | on_checkpoint(global_step + global_step_offset, "milestone") | 519 | on_checkpoint(global_step + global_step_offset, "milestone") |
512 | max_acc = avg_acc.avg.item() | 520 | best_acc = avg_acc.avg.item() |
521 | |||
522 | losses.append(avg_loss.avg.item()) | ||
523 | accs.append(avg_acc.avg.item()) | ||
513 | 524 | ||
514 | # Create the pipeline using using the trained modules and save it. | 525 | # Create the pipeline using using the trained modules and save it. |
515 | if accelerator.is_main_process: | 526 | if accelerator.is_main_process: |
@@ -523,6 +534,8 @@ def train_loop( | |||
523 | on_checkpoint(global_step + global_step_offset, "end") | 534 | on_checkpoint(global_step + global_step_offset, "end") |
524 | raise KeyboardInterrupt | 535 | raise KeyboardInterrupt |
525 | 536 | ||
537 | return lrs, losses, accs | ||
538 | |||
526 | 539 | ||
527 | def train( | 540 | def train( |
528 | accelerator: Accelerator, | 541 | accelerator: Accelerator, |
@@ -582,7 +595,7 @@ def train( | |||
582 | if accelerator.is_main_process: | 595 | if accelerator.is_main_process: |
583 | accelerator.init_trackers(project) | 596 | accelerator.init_trackers(project) |
584 | 597 | ||
585 | train_loop( | 598 | metrics = train_loop( |
586 | accelerator=accelerator, | 599 | accelerator=accelerator, |
587 | optimizer=optimizer, | 600 | optimizer=optimizer, |
588 | lr_scheduler=lr_scheduler, | 601 | lr_scheduler=lr_scheduler, |
@@ -598,3 +611,5 @@ def train( | |||
598 | 611 | ||
599 | accelerator.end_training() | 612 | accelerator.end_training() |
600 | accelerator.free_memory() | 613 | accelerator.free_memory() |
614 | |||
615 | return metrics | ||
diff --git a/training/lr.py b/training/lr.py index 9690738..f5b362f 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -1,238 +1,36 @@ | |||
1 | import math | 1 | from pathlib import Path |
2 | from contextlib import _GeneratorContextManager, nullcontext | ||
3 | from typing import Callable, Any, Tuple, Union | ||
4 | from functools import partial | ||
5 | 2 | ||
6 | import matplotlib.pyplot as plt | 3 | import matplotlib.pyplot as plt |
7 | import numpy as np | ||
8 | import torch | ||
9 | from torch.optim.lr_scheduler import LambdaLR | ||
10 | from tqdm.auto import tqdm | ||
11 | 4 | ||
12 | from training.functional import TrainingCallbacks | ||
13 | from training.util import AverageMeter | ||
14 | 5 | ||
15 | 6 | def plot_metrics( | |
16 | def noop(*args, **kwards): | 7 | metrics: tuple[list[float], list[float], list[float]], |
17 | pass | 8 | output_file: Path, |
18 | 9 | skip_start: int = 10, | |
19 | 10 | skip_end: int = 5, | |
20 | def noop_ctx(*args, **kwards): | 11 | ): |
21 | return nullcontext() | 12 | lrs, losses, accs = metrics |
22 | 13 | ||
23 | 14 | if skip_end == 0: | |
24 | class LRFinder(): | 15 | lrs = lrs[skip_start:] |
25 | def __init__( | 16 | losses = losses[skip_start:] |
26 | self, | 17 | accs = accs[skip_start:] |
27 | accelerator, | 18 | else: |
28 | optimizer, | 19 | lrs = lrs[skip_start:-skip_end] |
29 | train_dataloader, | 20 | losses = losses[skip_start:-skip_end] |
30 | val_dataloader, | 21 | accs = accs[skip_start:-skip_end] |
31 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 22 | |
32 | callbacks: TrainingCallbacks = TrainingCallbacks() | 23 | fig, ax_loss = plt.subplots() |
33 | ): | 24 | ax_acc = ax_loss.twinx() |
34 | self.accelerator = accelerator | 25 | |
35 | self.model = callbacks.on_model() | 26 | ax_loss.plot(lrs, losses, color='red') |
36 | self.optimizer = optimizer | 27 | ax_loss.set_xscale("log") |
37 | self.train_dataloader = train_dataloader | 28 | ax_loss.set_xlabel(f"Learning rate") |
38 | self.val_dataloader = val_dataloader | 29 | ax_loss.set_ylabel("Loss") |
39 | self.loss_fn = loss_fn | 30 | |
40 | self.callbacks = callbacks | 31 | ax_acc.plot(lrs, accs, color='blue') |
41 | 32 | ax_acc.set_xscale("log") | |
42 | # self.model_state = copy.deepcopy(model.state_dict()) | 33 | ax_acc.set_ylabel("Accuracy") |
43 | # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) | 34 | |
44 | 35 | plt.savefig(output_file, dpi=300) | |
45 | def run( | 36 | plt.close() |
46 | self, | ||
47 | end_lr, | ||
48 | skip_start: int = 10, | ||
49 | skip_end: int = 5, | ||
50 | num_epochs: int = 100, | ||
51 | num_train_batches: int = math.inf, | ||
52 | num_val_batches: int = math.inf, | ||
53 | smooth_f: float = 0.05, | ||
54 | ): | ||
55 | best_loss = None | ||
56 | best_acc = None | ||
57 | |||
58 | lrs = [] | ||
59 | losses = [] | ||
60 | accs = [] | ||
61 | |||
62 | lr_scheduler = get_exponential_schedule( | ||
63 | self.optimizer, | ||
64 | end_lr, | ||
65 | num_epochs * min(num_train_batches, len(self.train_dataloader)) | ||
66 | ) | ||
67 | |||
68 | steps = min(num_train_batches, len(self.train_dataloader)) | ||
69 | steps += min(num_val_batches, len(self.val_dataloader)) | ||
70 | steps *= num_epochs | ||
71 | |||
72 | progress_bar = tqdm( | ||
73 | range(steps), | ||
74 | disable=not self.accelerator.is_local_main_process, | ||
75 | dynamic_ncols=True | ||
76 | ) | ||
77 | progress_bar.set_description("Epoch X / Y") | ||
78 | |||
79 | self.callbacks.on_prepare() | ||
80 | |||
81 | on_train = self.callbacks.on_train | ||
82 | on_before_optimize = self.callbacks.on_before_optimize | ||
83 | on_after_optimize = self.callbacks.on_after_optimize | ||
84 | on_eval = self.callbacks.on_eval | ||
85 | |||
86 | for epoch in range(num_epochs): | ||
87 | progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | ||
88 | |||
89 | avg_loss = AverageMeter() | ||
90 | avg_acc = AverageMeter() | ||
91 | |||
92 | self.model.train() | ||
93 | |||
94 | with on_train(epoch): | ||
95 | for step, batch in enumerate(self.train_dataloader): | ||
96 | if step >= num_train_batches: | ||
97 | break | ||
98 | |||
99 | with self.accelerator.accumulate(self.model): | ||
100 | loss, acc, bsz = self.loss_fn(step, batch) | ||
101 | |||
102 | self.accelerator.backward(loss) | ||
103 | |||
104 | on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) | ||
105 | |||
106 | self.optimizer.step() | ||
107 | lr_scheduler.step() | ||
108 | self.optimizer.zero_grad(set_to_none=True) | ||
109 | |||
110 | if self.accelerator.sync_gradients: | ||
111 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | ||
112 | |||
113 | progress_bar.update(1) | ||
114 | |||
115 | self.model.eval() | ||
116 | |||
117 | with torch.inference_mode(): | ||
118 | with on_eval(): | ||
119 | for step, batch in enumerate(self.val_dataloader): | ||
120 | if step >= num_val_batches: | ||
121 | break | ||
122 | |||
123 | loss, acc, bsz = self.loss_fn(step, batch, True) | ||
124 | avg_loss.update(loss.detach_(), bsz) | ||
125 | avg_acc.update(acc.detach_(), bsz) | ||
126 | |||
127 | progress_bar.update(1) | ||
128 | |||
129 | loss = avg_loss.avg.item() | ||
130 | acc = avg_acc.avg.item() | ||
131 | |||
132 | if epoch == 0: | ||
133 | best_loss = loss | ||
134 | best_acc = acc | ||
135 | else: | ||
136 | if smooth_f > 0: | ||
137 | loss = smooth_f * loss + (1 - smooth_f) * losses[-1] | ||
138 | acc = smooth_f * acc + (1 - smooth_f) * accs[-1] | ||
139 | if loss < best_loss: | ||
140 | best_loss = loss | ||
141 | if acc > best_acc: | ||
142 | best_acc = acc | ||
143 | |||
144 | lr = lr_scheduler.get_last_lr()[0] | ||
145 | |||
146 | lrs.append(lr) | ||
147 | losses.append(loss) | ||
148 | accs.append(acc) | ||
149 | |||
150 | self.accelerator.log({ | ||
151 | "loss": loss, | ||
152 | "acc": acc, | ||
153 | "lr": lr, | ||
154 | }, step=epoch) | ||
155 | |||
156 | progress_bar.set_postfix({ | ||
157 | "loss": loss, | ||
158 | "loss/best": best_loss, | ||
159 | "acc": acc, | ||
160 | "acc/best": best_acc, | ||
161 | "lr": lr, | ||
162 | }) | ||
163 | |||
164 | # self.model.load_state_dict(self.model_state) | ||
165 | # self.optimizer.load_state_dict(self.optimizer_state) | ||
166 | |||
167 | if skip_end == 0: | ||
168 | lrs = lrs[skip_start:] | ||
169 | losses = losses[skip_start:] | ||
170 | accs = accs[skip_start:] | ||
171 | else: | ||
172 | lrs = lrs[skip_start:-skip_end] | ||
173 | losses = losses[skip_start:-skip_end] | ||
174 | accs = accs[skip_start:-skip_end] | ||
175 | |||
176 | fig, ax_loss = plt.subplots() | ||
177 | ax_acc = ax_loss.twinx() | ||
178 | |||
179 | ax_loss.plot(lrs, losses, color='red') | ||
180 | ax_loss.set_xscale("log") | ||
181 | ax_loss.set_xlabel(f"Learning rate") | ||
182 | ax_loss.set_ylabel("Loss") | ||
183 | |||
184 | ax_acc.plot(lrs, accs, color='blue') | ||
185 | ax_acc.set_xscale("log") | ||
186 | ax_acc.set_ylabel("Accuracy") | ||
187 | |||
188 | print("LR suggestion: steepest gradient") | ||
189 | min_grad_idx = None | ||
190 | |||
191 | try: | ||
192 | min_grad_idx = np.gradient(np.array(losses)).argmin() | ||
193 | except ValueError: | ||
194 | print( | ||
195 | "Failed to compute the gradients, there might not be enough points." | ||
196 | ) | ||
197 | |||
198 | try: | ||
199 | max_val_idx = np.array(accs).argmax() | ||
200 | except ValueError: | ||
201 | print( | ||
202 | "Failed to compute the gradients, there might not be enough points." | ||
203 | ) | ||
204 | |||
205 | if min_grad_idx is not None: | ||
206 | print("Suggested LR (loss): {:.2E}".format(lrs[min_grad_idx])) | ||
207 | ax_loss.scatter( | ||
208 | lrs[min_grad_idx], | ||
209 | losses[min_grad_idx], | ||
210 | s=75, | ||
211 | marker="o", | ||
212 | color="red", | ||
213 | zorder=3, | ||
214 | label="steepest gradient", | ||
215 | ) | ||
216 | ax_loss.legend() | ||
217 | |||
218 | if max_val_idx is not None: | ||
219 | print("Suggested LR (acc): {:.2E}".format(lrs[max_val_idx])) | ||
220 | ax_acc.scatter( | ||
221 | lrs[max_val_idx], | ||
222 | accs[max_val_idx], | ||
223 | s=75, | ||
224 | marker="o", | ||
225 | color="blue", | ||
226 | zorder=3, | ||
227 | label="maximum", | ||
228 | ) | ||
229 | ax_acc.legend() | ||
230 | |||
231 | |||
232 | def get_exponential_schedule(optimizer, end_lr: float, num_epochs: int, last_epoch: int = -1): | ||
233 | def lr_lambda(base_lr: float, current_epoch: int): | ||
234 | return (end_lr / base_lr) ** (current_epoch / num_epochs) | ||
235 | |||
236 | lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups] | ||
237 | |||
238 | return LambdaLR(optimizer, lr_lambdas, last_epoch) | ||
diff --git a/training/optimization.py b/training/optimization.py index 6dee4bc..6c9a35d 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -87,6 +87,15 @@ def get_one_cycle_schedule( | |||
87 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 87 | return LambdaLR(optimizer, lr_lambda, last_epoch) |
88 | 88 | ||
89 | 89 | ||
90 | def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1): | ||
91 | def lr_lambda(base_lr: float, current_step: int): | ||
92 | return (end_lr / base_lr) ** (current_step / num_training_steps) | ||
93 | |||
94 | lr_lambdas = [partial(lr_lambda, group["lr"]) for group in optimizer.param_groups] | ||
95 | |||
96 | return LambdaLR(optimizer, lr_lambdas, last_epoch) | ||
97 | |||
98 | |||
90 | def get_scheduler( | 99 | def get_scheduler( |
91 | id: str, | 100 | id: str, |
92 | optimizer: torch.optim.Optimizer, | 101 | optimizer: torch.optim.Optimizer, |
@@ -97,6 +106,7 @@ def get_scheduler( | |||
97 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", | 106 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", |
98 | warmup_exp: int = 1, | 107 | warmup_exp: int = 1, |
99 | annealing_exp: int = 1, | 108 | annealing_exp: int = 1, |
109 | end_lr: float = 1e3, | ||
100 | cycles: int = 1, | 110 | cycles: int = 1, |
101 | train_epochs: int = 100, | 111 | train_epochs: int = 100, |
102 | warmup_epochs: int = 10, | 112 | warmup_epochs: int = 10, |
@@ -117,6 +127,15 @@ def get_scheduler( | |||
117 | annealing_exp=annealing_exp, | 127 | annealing_exp=annealing_exp, |
118 | min_lr=min_lr, | 128 | min_lr=min_lr, |
119 | ) | 129 | ) |
130 | elif id == "exponential_growth": | ||
131 | if cycles is None: | ||
132 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | ||
133 | |||
134 | lr_scheduler = get_exponential_growing_schedule( | ||
135 | optimizer=optimizer, | ||
136 | end_lr=end_lr, | ||
137 | num_training_steps=num_training_steps, | ||
138 | ) | ||
120 | elif id == "cosine_with_restarts": | 139 | elif id == "cosine_with_restarts": |
121 | if cycles is None: | 140 | if cycles is None: |
122 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | 141 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 1277939..e88bf90 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -193,9 +193,7 @@ def dreambooth_prepare( | |||
193 | unet: UNet2DConditionModel, | 193 | unet: UNet2DConditionModel, |
194 | *args | 194 | *args |
195 | ): | 195 | ): |
196 | prep = [text_encoder, unet] + list(args) | 196 | return accelerator.prepare(text_encoder, unet, *args) |
197 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) | ||
198 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
199 | 197 | ||
200 | 198 | ||
201 | dreambooth_strategy = TrainingStrategy( | 199 | dreambooth_strategy = TrainingStrategy( |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6a76f98..14bdafd 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -176,10 +176,9 @@ def textual_inversion_prepare( | |||
176 | elif accelerator.state.mixed_precision == "bf16": | 176 | elif accelerator.state.mixed_precision == "bf16": |
177 | weight_dtype = torch.bfloat16 | 177 | weight_dtype = torch.bfloat16 |
178 | 178 | ||
179 | prep = [text_encoder] + list(args) | 179 | prepped = accelerator.prepare(text_encoder, *args) |
180 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep) | ||
181 | unet.to(accelerator.device, dtype=weight_dtype) | 180 | unet.to(accelerator.device, dtype=weight_dtype) |
182 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 181 | return (prepped[0], unet) + prepped[1:] |
183 | 182 | ||
184 | 183 | ||
185 | textual_inversion_strategy = TrainingStrategy( | 184 | textual_inversion_strategy = TrainingStrategy( |
diff --git a/training/util.py b/training/util.py index 237626f..c8524de 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -6,6 +6,8 @@ from contextlib import contextmanager | |||
6 | 6 | ||
7 | import torch | 7 | import torch |
8 | 8 | ||
9 | from diffusers.training_utils import EMAModel as EMAModel_ | ||
10 | |||
9 | 11 | ||
10 | def save_args(basepath: Path, args, extra={}): | 12 | def save_args(basepath: Path, args, extra={}): |
11 | info = {"args": vars(args)} | 13 | info = {"args": vars(args)} |
@@ -30,149 +32,7 @@ class AverageMeter: | |||
30 | self.avg = self.sum / self.count | 32 | self.avg = self.sum / self.count |
31 | 33 | ||
32 | 34 | ||
33 | # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 | 35 | class EMAModel(EMAModel_): |
34 | class EMAModel: | ||
35 | """ | ||
36 | Exponential Moving Average of models weights | ||
37 | """ | ||
38 | |||
39 | def __init__( | ||
40 | self, | ||
41 | parameters: Iterable[torch.nn.Parameter], | ||
42 | update_after_step: int = 0, | ||
43 | inv_gamma: float = 1.0, | ||
44 | power: float = 2 / 3, | ||
45 | min_value: float = 0.0, | ||
46 | max_value: float = 0.9999, | ||
47 | ): | ||
48 | """ | ||
49 | @crowsonkb's notes on EMA Warmup: | ||
50 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan | ||
51 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), | ||
52 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 | ||
53 | at 215.4k steps). | ||
54 | Args: | ||
55 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. | ||
56 | power (float): Exponential factor of EMA warmup. Default: 2/3. | ||
57 | min_value (float): The minimum EMA decay rate. Default: 0. | ||
58 | """ | ||
59 | parameters = list(parameters) | ||
60 | self.shadow_params = [p.clone().detach() for p in parameters] | ||
61 | |||
62 | self.collected_params = None | ||
63 | |||
64 | self.update_after_step = update_after_step | ||
65 | self.inv_gamma = inv_gamma | ||
66 | self.power = power | ||
67 | self.min_value = min_value | ||
68 | self.max_value = max_value | ||
69 | |||
70 | self.decay = 0.0 | ||
71 | self.optimization_step = 0 | ||
72 | |||
73 | def get_decay(self, optimization_step: int): | ||
74 | """ | ||
75 | Compute the decay factor for the exponential moving average. | ||
76 | """ | ||
77 | step = max(0, optimization_step - self.update_after_step - 1) | ||
78 | value = 1 - (1 + step / self.inv_gamma) ** -self.power | ||
79 | |||
80 | if step <= 0: | ||
81 | return 0.0 | ||
82 | |||
83 | return max(self.min_value, min(value, self.max_value)) | ||
84 | |||
85 | @torch.no_grad() | ||
86 | def step(self, parameters): | ||
87 | parameters = list(parameters) | ||
88 | |||
89 | self.optimization_step += 1 | ||
90 | |||
91 | # Compute the decay factor for the exponential moving average. | ||
92 | self.decay = self.get_decay(self.optimization_step) | ||
93 | |||
94 | for s_param, param in zip(self.shadow_params, parameters): | ||
95 | if param.requires_grad: | ||
96 | s_param.mul_(self.decay) | ||
97 | s_param.add_(param.data, alpha=1 - self.decay) | ||
98 | else: | ||
99 | s_param.copy_(param) | ||
100 | |||
101 | torch.cuda.empty_cache() | ||
102 | |||
103 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: | ||
104 | """ | ||
105 | Copy current averaged parameters into given collection of parameters. | ||
106 | Args: | ||
107 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be | ||
108 | updated with the stored moving averages. If `None`, the | ||
109 | parameters with which this `ExponentialMovingAverage` was | ||
110 | initialized will be used. | ||
111 | """ | ||
112 | parameters = list(parameters) | ||
113 | for s_param, param in zip(self.shadow_params, parameters): | ||
114 | param.data.copy_(s_param.data) | ||
115 | |||
116 | def to(self, device=None, dtype=None) -> None: | ||
117 | r"""Move internal buffers of the ExponentialMovingAverage to `device`. | ||
118 | Args: | ||
119 | device: like `device` argument to `torch.Tensor.to` | ||
120 | """ | ||
121 | # .to() on the tensors handles None correctly | ||
122 | self.shadow_params = [ | ||
123 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) | ||
124 | for p in self.shadow_params | ||
125 | ] | ||
126 | |||
127 | def state_dict(self) -> dict: | ||
128 | r""" | ||
129 | Returns the state of the ExponentialMovingAverage as a dict. | ||
130 | This method is used by accelerate during checkpointing to save the ema state dict. | ||
131 | """ | ||
132 | # Following PyTorch conventions, references to tensors are returned: | ||
133 | # "returns a reference to the state and not its copy!" - | ||
134 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict | ||
135 | return { | ||
136 | "decay": self.decay, | ||
137 | "optimization_step": self.optimization_step, | ||
138 | "shadow_params": self.shadow_params, | ||
139 | "collected_params": self.collected_params, | ||
140 | } | ||
141 | |||
142 | def load_state_dict(self, state_dict: dict) -> None: | ||
143 | r""" | ||
144 | Loads the ExponentialMovingAverage state. | ||
145 | This method is used by accelerate during checkpointing to save the ema state dict. | ||
146 | Args: | ||
147 | state_dict (dict): EMA state. Should be an object returned | ||
148 | from a call to :meth:`state_dict`. | ||
149 | """ | ||
150 | # deepcopy, to be consistent with module API | ||
151 | state_dict = copy.deepcopy(state_dict) | ||
152 | |||
153 | self.decay = state_dict["decay"] | ||
154 | if self.decay < 0.0 or self.decay > 1.0: | ||
155 | raise ValueError("Decay must be between 0 and 1") | ||
156 | |||
157 | self.optimization_step = state_dict["optimization_step"] | ||
158 | if not isinstance(self.optimization_step, int): | ||
159 | raise ValueError("Invalid optimization_step") | ||
160 | |||
161 | self.shadow_params = state_dict["shadow_params"] | ||
162 | if not isinstance(self.shadow_params, list): | ||
163 | raise ValueError("shadow_params must be a list") | ||
164 | if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): | ||
165 | raise ValueError("shadow_params must all be Tensors") | ||
166 | |||
167 | self.collected_params = state_dict["collected_params"] | ||
168 | if self.collected_params is not None: | ||
169 | if not isinstance(self.collected_params, list): | ||
170 | raise ValueError("collected_params must be a list") | ||
171 | if not all(isinstance(p, torch.Tensor) for p in self.collected_params): | ||
172 | raise ValueError("collected_params must all be Tensors") | ||
173 | if len(self.collected_params) != len(self.shadow_params): | ||
174 | raise ValueError("collected_params and shadow_params must have the same length") | ||
175 | |||
176 | @contextmanager | 36 | @contextmanager |
177 | def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): | 37 | def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): |
178 | parameters = list(parameters) | 38 | parameters = list(parameters) |