summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 10:02:30 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 10:02:30 +0100
commita2c240c8c55dfe930657f66372975d6f26feb168 (patch)
tree61c22b830098a6a28f885d9a0964b02a7f429e30 /training/util.py
parentFix (diff)
downloadtextual-inversion-diff-a2c240c8c55dfe930657f66372975d6f26feb168.tar.gz
textual-inversion-diff-a2c240c8c55dfe930657f66372975d6f26feb168.tar.bz2
textual-inversion-diff-a2c240c8c55dfe930657f66372975d6f26feb168.zip
TI: Prepare UNet with Accelerate as well
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/training/util.py b/training/util.py
index 1008021..781cf04 100644
--- a/training/util.py
+++ b/training/util.py
@@ -134,11 +134,11 @@ class EMAModel:
134 def __init__( 134 def __init__(
135 self, 135 self,
136 parameters: Iterable[torch.nn.Parameter], 136 parameters: Iterable[torch.nn.Parameter],
137 update_after_step=0, 137 update_after_step: int = 0,
138 inv_gamma=1.0, 138 inv_gamma: float = 1.0,
139 power=2 / 3, 139 power: float = 2 / 3,
140 min_value=0.0, 140 min_value: float = 0.0,
141 max_value=0.9999, 141 max_value: float = 0.9999,
142 ): 142 ):
143 """ 143 """
144 @crowsonkb's notes on EMA Warmup: 144 @crowsonkb's notes on EMA Warmup:
@@ -165,7 +165,7 @@ class EMAModel:
165 self.decay = 0.0 165 self.decay = 0.0
166 self.optimization_step = 0 166 self.optimization_step = 0
167 167
168 def get_decay(self, optimization_step): 168 def get_decay(self, optimization_step: int):
169 """ 169 """
170 Compute the decay factor for the exponential moving average. 170 Compute the decay factor for the exponential moving average.
171 """ 171 """
@@ -276,5 +276,5 @@ class EMAModel:
276 self.copy_to(parameters) 276 self.copy_to(parameters)
277 yield 277 yield
278 finally: 278 finally:
279 for s_param, param in zip(original_params, parameters): 279 for o_param, param in zip(original_params, parameters):
280 param.data.copy_(s_param.data) 280 param.data.copy_(o_param.data)