summaryrefslogtreecommitdiffstats
path: root/training/sampler.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-05-16 16:48:51 +0200
committerVolpeon <git@volpeon.ink>2023-05-16 16:48:51 +0200
commit55a12f2c683b2ecfa4fc8b4015462ad2798abda5 (patch)
treefeeb3f9a041466e773bb5921cbf0adb208d60a49 /training/sampler.py
parentAvoid model recompilation due to varying prompt lengths (diff)
downloadtextual-inversion-diff-55a12f2c683b2ecfa4fc8b4015462ad2798abda5.tar.gz
textual-inversion-diff-55a12f2c683b2ecfa4fc8b4015462ad2798abda5.tar.bz2
textual-inversion-diff-55a12f2c683b2ecfa4fc8b4015462ad2798abda5.zip
Fix LoRA training with DAdan
Diffstat (limited to 'training/sampler.py')
-rw-r--r--training/sampler.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/training/sampler.py b/training/sampler.py
index 8afe255..bdb3e90 100644
--- a/training/sampler.py
+++ b/training/sampler.py
@@ -129,7 +129,7 @@ class LossSecondMomentResampler(LossAwareSampler):
129 self._loss_history = np.zeros( 129 self._loss_history = np.zeros(
130 [self.num_timesteps, history_per_term], dtype=np.float64 130 [self.num_timesteps, history_per_term], dtype=np.float64
131 ) 131 )
132 self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int) 132 self._loss_counts = np.zeros([self.num_timesteps], dtype=int)
133 133
134 def weights(self): 134 def weights(self):
135 if not self._warmed_up(): 135 if not self._warmed_up():