diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-14 10:02:30 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 10:02:30 +0100 |
| commit | a2c240c8c55dfe930657f66372975d6f26feb168 (patch) | |
| tree | 61c22b830098a6a28f885d9a0964b02a7f429e30 /training/util.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-a2c240c8c55dfe930657f66372975d6f26feb168.tar.gz textual-inversion-diff-a2c240c8c55dfe930657f66372975d6f26feb168.tar.bz2 textual-inversion-diff-a2c240c8c55dfe930657f66372975d6f26feb168.zip | |
TI: Prepare UNet with Accelerate as well
Diffstat (limited to 'training/util.py')
| -rw-r--r-- | training/util.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/training/util.py b/training/util.py index 1008021..781cf04 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -134,11 +134,11 @@ class EMAModel: | |||
| 134 | def __init__( | 134 | def __init__( |
| 135 | self, | 135 | self, |
| 136 | parameters: Iterable[torch.nn.Parameter], | 136 | parameters: Iterable[torch.nn.Parameter], |
| 137 | update_after_step=0, | 137 | update_after_step: int = 0, |
| 138 | inv_gamma=1.0, | 138 | inv_gamma: float = 1.0, |
| 139 | power=2 / 3, | 139 | power: float = 2 / 3, |
| 140 | min_value=0.0, | 140 | min_value: float = 0.0, |
| 141 | max_value=0.9999, | 141 | max_value: float = 0.9999, |
| 142 | ): | 142 | ): |
| 143 | """ | 143 | """ |
| 144 | @crowsonkb's notes on EMA Warmup: | 144 | @crowsonkb's notes on EMA Warmup: |
| @@ -165,7 +165,7 @@ class EMAModel: | |||
| 165 | self.decay = 0.0 | 165 | self.decay = 0.0 |
| 166 | self.optimization_step = 0 | 166 | self.optimization_step = 0 |
| 167 | 167 | ||
| 168 | def get_decay(self, optimization_step): | 168 | def get_decay(self, optimization_step: int): |
| 169 | """ | 169 | """ |
| 170 | Compute the decay factor for the exponential moving average. | 170 | Compute the decay factor for the exponential moving average. |
| 171 | """ | 171 | """ |
| @@ -276,5 +276,5 @@ class EMAModel: | |||
| 276 | self.copy_to(parameters) | 276 | self.copy_to(parameters) |
| 277 | yield | 277 | yield |
| 278 | finally: | 278 | finally: |
| 279 | for s_param, param in zip(original_params, parameters): | 279 | for o_param, param in zip(original_params, parameters): |
| 280 | param.data.copy_(s_param.data) | 280 | param.data.copy_(o_param.data) |
