summaryrefslogtreecommitdiffstats
path: root/training/sampler.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-21 13:28:49 +0200
committerVolpeon <git@volpeon.ink>2023-06-21 13:28:49 +0200
commit8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree152c99815bbd8b2659d0dabe63c98f63151c97c2 /training/sampler.py
parentFix LoRA training with DAdan (diff)
downloadtextual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz
textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2
textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip
Update
Diffstat (limited to 'training/sampler.py')
-rw-r--r--training/sampler.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/training/sampler.py b/training/sampler.py
index bdb3e90..0487d66 100644
--- a/training/sampler.py
+++ b/training/sampler.py
@@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler):
134 def weights(self): 134 def weights(self):
135 if not self._warmed_up(): 135 if not self._warmed_up():
136 return np.ones([self.num_timesteps], dtype=np.float64) 136 return np.ones([self.num_timesteps], dtype=np.float64)
137 weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 137 weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
138 weights /= np.sum(weights) 138 weights /= np.sum(weights)
139 weights *= 1 - self.uniform_prob 139 weights *= 1 - self.uniform_prob
140 weights += self.uniform_prob / len(weights) 140 weights += self.uniform_prob / len(weights)