summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/training/util.py b/training/util.py
index 60d64f0..0ec2032 100644
--- a/training/util.py
+++ b/training/util.py
@@ -55,8 +55,19 @@ class CheckpointerBase:
55 self.sample_batches = sample_batches 55 self.sample_batches = sample_batches
56 self.sample_batch_size = sample_batch_size 56 self.sample_batch_size = sample_batch_size
57 57
58 @torch.no_grad()
59 def checkpoint(self, step: int, postfix: str):
60 pass
61
58 @torch.inference_mode() 62 @torch.inference_mode()
59 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 63 def save_samples(
64 self,
65 pipeline,
66 step: int,
67 num_inference_steps: int,
68 guidance_scale: float = 7.5,
69 eta: float = 0.0
70 ):
60 samples_path = Path(self.output_dir).joinpath("samples") 71 samples_path = Path(self.output_dir).joinpath("samples")
61 72
62 train_data = self.datamodule.train_dataloader 73 train_data = self.datamodule.train_dataloader