summaryrefslogtreecommitdiffstats
path: root/training/sampler.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/sampler.py')
-rw-r--r--training/sampler.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/training/sampler.py b/training/sampler.py
index bdb3e90..0487d66 100644
--- a/training/sampler.py
+++ b/training/sampler.py
@@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler):
134 def weights(self): 134 def weights(self):
135 if not self._warmed_up(): 135 if not self._warmed_up():
136 return np.ones([self.num_timesteps], dtype=np.float64) 136 return np.ones([self.num_timesteps], dtype=np.float64)
137 weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 137 weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
138 weights /= np.sum(weights) 138 weights /= np.sum(weights)
139 weights *= 1 - self.uniform_prob 139 weights *= 1 - self.uniform_prob
140 weights += self.uniform_prob / len(weights) 140 weights += self.uniform_prob / len(weights)