From e2d3a62bce63fcde940395a1c5618c4eb43385a9 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 09:25:13 +0100 Subject: Cleanup --- training/util.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py index cc4cdee..1008021 100644 --- a/training/util.py +++ b/training/util.py @@ -44,32 +44,29 @@ class CheckpointerBase: train_dataloader, val_dataloader, output_dir: Path, - sample_image_size: int, - sample_batches: int, - sample_batch_size: int, + sample_steps: int = 20, + sample_guidance_scale: float = 7.5, + sample_image_size: int = 768, + sample_batches: int = 1, + sample_batch_size: int = 1, seed: Optional[int] = None ): self.train_dataloader = train_dataloader self.val_dataloader = val_dataloader self.output_dir = output_dir self.sample_image_size = sample_image_size - self.seed = seed if seed is not None else torch.random.seed() + self.sample_steps = sample_steps + self.sample_guidance_scale = sample_guidance_scale self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size + self.seed = seed if seed is not None else torch.random.seed() @torch.no_grad() def checkpoint(self, step: int, postfix: str): pass @torch.inference_mode() - def save_samples( - self, - pipeline, - step: int, - num_inference_steps: int, - guidance_scale: float = 7.5, - eta: float = 0.0 - ): + def save_samples(self, pipeline, step: int): samples_path = Path(self.output_dir).joinpath("samples") generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) @@ -110,9 +107,8 @@ class CheckpointerBase: height=self.sample_image_size, width=self.sample_image_size, generator=gen, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, + guidance_scale=self.sample_guidance_scale, + num_inference_steps=self.sample_steps, output_type='pil' ).images -- cgit v1.2.3-54-g00ecf