summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/util.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/training/util.py b/training/util.py
index bc466e2..6f42228 100644
--- a/training/util.py
+++ b/training/util.py
@@ -58,8 +58,8 @@ class CheckpointerBase:
58 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 58 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
59 samples_path = Path(self.output_dir).joinpath("samples") 59 samples_path = Path(self.output_dir).joinpath("samples")
60 60
61 train_data = self.datamodule.train_dataloader() 61 train_data = self.datamodule.train_dataloaders[0]
62 val_data = self.datamodule.val_dataloader() 62 val_data = self.datamodule.val_dataloader
63 63
64 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) 64 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
65 65