diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 4 | ||||
-rw-r--r-- | training/sampler.py | 2 |
2 files changed, 3 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py index 10560e5..fd3f9f4 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -710,8 +710,8 @@ def train( | |||
710 | vae = torch.compile(vae, backend='hidet') | 710 | vae = torch.compile(vae, backend='hidet') |
711 | 711 | ||
712 | if compile_unet: | 712 | if compile_unet: |
713 | # unet = torch.compile(unet, backend='hidet') | 713 | unet = torch.compile(unet, backend='hidet') |
714 | unet = torch.compile(unet, mode="reduce-overhead") | 714 | # unet = torch.compile(unet, mode="reduce-overhead") |
715 | 715 | ||
716 | callbacks = strategy.callbacks( | 716 | callbacks = strategy.callbacks( |
717 | accelerator=accelerator, | 717 | accelerator=accelerator, |
diff --git a/training/sampler.py b/training/sampler.py index 8afe255..bdb3e90 100644 --- a/training/sampler.py +++ b/training/sampler.py | |||
@@ -129,7 +129,7 @@ class LossSecondMomentResampler(LossAwareSampler): | |||
129 | self._loss_history = np.zeros( | 129 | self._loss_history = np.zeros( |
130 | [self.num_timesteps, history_per_term], dtype=np.float64 | 130 | [self.num_timesteps, history_per_term], dtype=np.float64 |
131 | ) | 131 | ) |
132 | self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int) | 132 | self._loss_counts = np.zeros([self.num_timesteps], dtype=int) |
133 | 133 | ||
134 | def weights(self): | 134 | def weights(self): |
135 | if not self._warmed_up(): | 135 | if not self._warmed_up(): |