From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 13:49:35 +0100 Subject: Removed PromptProcessor, modularized training loop --- training/util.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py index 60d64f0..0ec2032 100644 --- a/training/util.py +++ b/training/util.py @@ -55,8 +55,19 @@ class CheckpointerBase: self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size + @torch.no_grad() + def checkpoint(self, step: int, postfix: str): + pass + @torch.inference_mode() - def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): + def save_samples( + self, + pipeline, + step: int, + num_inference_steps: int, + guidance_scale: float = 7.5, + eta: float = 0.0 + ): samples_path = Path(self.output_dir).joinpath("samples") train_data = self.datamodule.train_dataloader -- cgit v1.2.3-54-g00ecf