Diffstat (limited to 'training/util.py')
-rw-r--r--  training/util.py  146
1 file changed, 3 insertions(+), 143 deletions(-)
diff --git a/training/util.py b/training/util.py
index 237626f..c8524de 100644
--- a/training/util.py
+++ b/training/util.py
@@ -6,6 +6,8 @@ from contextlib import contextmanager
 
 import torch
 
+from diffusers.training_utils import EMAModel as EMAModel_
+
 
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
@@ -30,149 +32,7 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
-class EMAModel:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(
-        self,
-        parameters: Iterable[torch.nn.Parameter],
-        update_after_step: int = 0,
-        inv_gamma: float = 1.0,
-        power: float = 2 / 3,
-        min_value: float = 0.0,
-        max_value: float = 0.9999,
-    ):
-        """
-        @crowsonkb's notes on EMA Warmup:
-        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
-        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
-        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
-        at 215.4k steps).
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_value (float): The minimum EMA decay rate. Default: 0.
-        """
-        parameters = list(parameters)
-        self.shadow_params = [p.clone().detach() for p in parameters]
-
-        self.collected_params = None
-
-        self.update_after_step = update_after_step
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
-
-        self.decay = 0.0
-        self.optimization_step = 0
-
-    def get_decay(self, optimization_step: int):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
-
-        if step <= 0:
-            return 0.0
-
-        return max(self.min_value, min(value, self.max_value))
-
-    @torch.no_grad()
-    def step(self, parameters):
-        parameters = list(parameters)
-
-        self.optimization_step += 1
-
-        # Compute the decay factor for the exponential moving average.
-        self.decay = self.get_decay(self.optimization_step)
-
-        for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                s_param.mul_(self.decay)
-                s_param.add_(param.data, alpha=1 - self.decay)
-            else:
-                s_param.copy_(param)
-
-        torch.cuda.empty_cache()
-
-    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Copy current averaged parameters into given collection of parameters.
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored moving averages. If `None`, the
-                parameters with which this `ExponentialMovingAverage` was
-                initialized will be used.
-        """
-        parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
-
-    def to(self, device=None, dtype=None) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
-        Args:
-            device: like `device` argument to `torch.Tensor.to`
-        """
-        # .to() on the tensors handles None correctly
-        self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
-            for p in self.shadow_params
-        ]
-
-    def state_dict(self) -> dict:
-        r"""
-        Returns the state of the ExponentialMovingAverage as a dict.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        """
-        # Following PyTorch conventions, references to tensors are returned:
-        # "returns a reference to the state and not its copy!" -
-        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
-        return {
-            "decay": self.decay,
-            "optimization_step": self.optimization_step,
-            "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
-        }
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        r"""
-        Loads the ExponentialMovingAverage state.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        Args:
-            state_dict (dict): EMA state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        # deepcopy, to be consistent with module API
-        state_dict = copy.deepcopy(state_dict)
-
-        self.decay = state_dict["decay"]
-        if self.decay < 0.0 or self.decay > 1.0:
-            raise ValueError("Decay must be between 0 and 1")
-
-        self.optimization_step = state_dict["optimization_step"]
-        if not isinstance(self.optimization_step, int):
-            raise ValueError("Invalid optimization_step")
-
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
-
-        self.collected_params = state_dict["collected_params"]
-        if self.collected_params is not None:
-            if not isinstance(self.collected_params, list):
-                raise ValueError("collected_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
-                raise ValueError("collected_params must all be Tensors")
-            if len(self.collected_params) != len(self.shadow_params):
-                raise ValueError("collected_params and shadow_params must have the same length")
-
+class EMAModel(EMAModel_):
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
         parameters = list(parameters)
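
For reference, the deleted class implemented @crowsonkb's EMA warmup schedule: decay = 1 - (1 + step / inv_gamma) ** -power, clamped to [min_value, max_value], with step offset by update_after_step. A standalone sketch of the deleted get_decay() logic (not the diffusers replacement), handy for sanity-checking the figures quoted in the removed docstring:

# Sketch of the removed get_decay() as a free function.
def ema_decay(optimization_step, update_after_step=0, inv_gamma=1.0,
              power=2 / 3, min_value=0.0, max_value=0.9999):
    step = max(0, optimization_step - update_after_step - 1)
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

print(round(ema_decay(31_600), 4))               # 0.999  ("0.999 at 31.6K steps")
print(ema_decay(1_000_000))                      # 0.9999 (hits the max_value clamp)
print(round(ema_decay(10_000, power=3 / 4), 4))  # 0.999  ("0.999 at 10K steps" with power=3/4)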
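
After the rewrite, everything else is inherited from diffusers.training_utils.EMAModel; only apply_temporary stays local, and the hunk cuts off after its first line. Assuming the base class exposes the store()/copy_to()/restore() methods, a minimal sketch of what such a context manager looks like (the body past the first line is an assumption, not shown in this diff):

from contextlib import contextmanager
from typing import Iterable

import torch

from diffusers.training_utils import EMAModel as EMAModel_


class EMAModel(EMAModel_):
    @contextmanager
    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
        # Sketch only: swap the EMA averages in, yield, then restore the
        # live training weights, assuming diffusers' store/copy_to/restore.
        parameters = list(parameters)
        self.store(parameters)    # stash the current training weights
        self.copy_to(parameters)  # load the EMA-averaged weights
        try:
            yield
        finally:
            self.restore(parameters)  # put the training weights back

Typical use is to wrap a validation pass, so the model is evaluated with the averaged weights and the training weights are back in place afterwards.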