| author | Volpeon <git@volpeon.ink> | 2023-01-20 14:26:17 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-20 14:26:17 +0100 |
| commit | 3575d041f1507811b577fd2c653171fb51c0a386 (patch) | |
| tree | 702f9f1ae4eafc6f8ea06560c4de6bbe1c2acecb /training/util.py | |
| parent | Move Accelerator preparation into strategy (diff) | |
| download | textual-inversion-diff-3575d041f1507811b577fd2c653171fb51c0a386.tar.gz textual-inversion-diff-3575d041f1507811b577fd2c653171fb51c0a386.tar.bz2 textual-inversion-diff-3575d041f1507811b577fd2c653171fb51c0a386.zip | |
Restored LR finder
Diffstat (limited to 'training/util.py')
| -rw-r--r-- | training/util.py | 146 |
1 file changed, 3 insertions(+), 143 deletions(-)
```diff
diff --git a/training/util.py b/training/util.py
index 237626f..c8524de 100644
--- a/training/util.py
+++ b/training/util.py
@@ -6,6 +6,8 @@ from contextlib import contextmanager
 
 import torch
 
+from diffusers.training_utils import EMAModel as EMAModel_
+
 
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
@@ -30,149 +32,7 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
-class EMAModel:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(
-        self,
-        parameters: Iterable[torch.nn.Parameter],
-        update_after_step: int = 0,
-        inv_gamma: float = 1.0,
-        power: float = 2 / 3,
-        min_value: float = 0.0,
-        max_value: float = 0.9999,
-    ):
-        """
-        @crowsonkb's notes on EMA Warmup:
-            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
-            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
-            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
-            at 215.4k steps).
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_value (float): The minimum EMA decay rate. Default: 0.
-        """
-        parameters = list(parameters)
-        self.shadow_params = [p.clone().detach() for p in parameters]
-
-        self.collected_params = None
-
-        self.update_after_step = update_after_step
-        self.inv_gamma = inv_gamma
-        self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
-
-        self.decay = 0.0
-        self.optimization_step = 0
-
-    def get_decay(self, optimization_step: int):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
-
-        if step <= 0:
-            return 0.0
-
-        return max(self.min_value, min(value, self.max_value))
-
-    @torch.no_grad()
-    def step(self, parameters):
-        parameters = list(parameters)
-
-        self.optimization_step += 1
-
-        # Compute the decay factor for the exponential moving average.
-        self.decay = self.get_decay(self.optimization_step)
-
-        for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                s_param.mul_(self.decay)
-                s_param.add_(param.data, alpha=1 - self.decay)
-            else:
-                s_param.copy_(param)
-
-        torch.cuda.empty_cache()
-
-    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Copy current averaged parameters into given collection of parameters.
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored moving averages. If `None`, the
-                parameters with which this `ExponentialMovingAverage` was
-                initialized will be used.
-        """
-        parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
-
-    def to(self, device=None, dtype=None) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
-        Args:
-            device: like `device` argument to `torch.Tensor.to`
-        """
-        # .to() on the tensors handles None correctly
-        self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
-            for p in self.shadow_params
-        ]
-
-    def state_dict(self) -> dict:
-        r"""
-        Returns the state of the ExponentialMovingAverage as a dict.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        """
-        # Following PyTorch conventions, references to tensors are returned:
-        # "returns a reference to the state and not its copy!" -
-        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
-        return {
-            "decay": self.decay,
-            "optimization_step": self.optimization_step,
-            "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
-        }
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        r"""
-        Loads the ExponentialMovingAverage state.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        Args:
-            state_dict (dict): EMA state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        # deepcopy, to be consistent with module API
-        state_dict = copy.deepcopy(state_dict)
-
-        self.decay = state_dict["decay"]
-        if self.decay < 0.0 or self.decay > 1.0:
-            raise ValueError("Decay must be between 0 and 1")
-
-        self.optimization_step = state_dict["optimization_step"]
-        if not isinstance(self.optimization_step, int):
-            raise ValueError("Invalid optimization_step")
-
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
-
-        self.collected_params = state_dict["collected_params"]
-        if self.collected_params is not None:
-            if not isinstance(self.collected_params, list):
-                raise ValueError("collected_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
-                raise ValueError("collected_params must all be Tensors")
-            if len(self.collected_params) != len(self.shadow_params):
-                raise ValueError("collected_params and shadow_params must have the same length")
-
+class EMAModel(EMAModel_):
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
        parameters = list(parameters)
```
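For reference, the change drops the repository's local EMA helper (adapted from torch-ema) and keeps only a thin subclass of `diffusers.training_utils.EMAModel` that adds the `apply_temporary` context manager; decay scheduling, `step`, `copy_to`, and checkpointing now come from the diffusers base class. The diff truncates the body of `apply_temporary`, so the sketch below is only an assumption about how such a helper could be built on top of the base class's `copy_to()`; the constructor keyword arguments also depend on the installed diffusers version.

```python
# Hedged sketch, not the repository's actual code: the apply_temporary body is
# truncated in the diff above, so the clone/copy_to/restore logic is assumed.
from contextlib import contextmanager
from typing import Iterable

import torch

from diffusers.training_utils import EMAModel as EMAModel_


class EMAModel(EMAModel_):
    @contextmanager
    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
        parameters = list(parameters)
        # Keep copies of the live weights so they can be restored afterwards.
        original_params = [p.detach().clone() for p in parameters]
        # copy_to() is provided by the diffusers base class; it writes the
        # averaged (shadow) weights into the given parameters.
        self.copy_to(parameters)
        try:
            yield
        finally:
            # Put the live training weights back in place.
            for o_param, param in zip(original_params, parameters):
                param.data.copy_(o_param)


# Usage sketch (model/sample are placeholders; EMA kwargs depend on the
# installed diffusers version):
#
#     ema = EMAModel(model.parameters())
#     for batch in dataloader:
#         ...optimizer.step()...
#         ema.step(model.parameters())
#     with ema.apply_temporary(model.parameters()):
#         sample(model)  # evaluate with EMA weights, then restore live weights
```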

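The warmup figures quoted in the removed docstring (decay 0.999 at 31.6K steps and 0.9999 at 1M steps for power=2/3; 0.999 at 10K and 0.9999 at 215.4k steps for power=3/4) follow directly from the schedule used in the removed `get_decay`, `decay = 1 - (1 + step / inv_gamma) ** -power`. A quick check with inv_gamma=1, ignoring the `update_after_step` offset and the min/max clamping:

```python
# Sanity check of the EMA warmup numbers quoted in the removed docstring,
# using the schedule from get_decay(): decay = 1 - (1 + step / inv_gamma) ** -power
def warmup_decay(step: int, inv_gamma: float = 1.0, power: float = 2 / 3) -> float:
    return 1 - (1 + step / inv_gamma) ** -power


print(round(warmup_decay(31_623, power=2 / 3), 4))     # 0.999  at ~31.6K steps
print(round(warmup_decay(1_000_000, power=2 / 3), 5))  # 0.9999 at 1M steps
print(round(warmup_decay(10_000, power=3 / 4), 4))     # 0.999  at 10K steps
print(round(warmup_decay(215_443, power=3 / 4), 5))    # 0.9999 at ~215.4k steps
```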