From 7505f7e843dc719622a15f4ee301609813763d77 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Dec 2022 23:50:24 +0100 Subject: Code simplifications, avoid autocast --- training/util.py | 106 +++++++++++++++++++++++++------------------------------ 1 file changed, 49 insertions(+), 57 deletions(-) (limited to 'training') diff --git a/training/util.py b/training/util.py index 5c056a6..a0c15cd 100644 --- a/training/util.py +++ b/training/util.py @@ -60,7 +60,7 @@ class CheckpointerBase: self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size - @torch.no_grad() + @torch.inference_mode() def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): samples_path = Path(self.output_dir).joinpath("samples") @@ -68,65 +68,57 @@ class CheckpointerBase: val_data = self.datamodule.val_dataloader() generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) - stable_latents = torch.randn( - (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8), - device=pipeline.device, - generator=generator, - ) grid_cols = min(self.sample_batch_size, 4) grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols - with torch.autocast("cuda"), torch.inference_mode(): - for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: - all_samples = [] - file_path = samples_path.joinpath(pool, f"step_{step}.jpg") - file_path.parent.mkdir(parents=True, exist_ok=True) - - data_enum = enumerate(data) - - batches = [ - batch - for j, batch in data_enum - if j * data.batch_size < self.sample_batch_size * self.sample_batches - ] - prompts = [ - prompt - for batch in batches - for prompt in batch["prompts"] - ] - nprompts = [ - prompt - for batch in batches - for prompt in batch["nprompts"] - ] - - for i in range(self.sample_batches): - prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] - nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] - - samples = pipeline( - prompt=prompt, - negative_prompt=nprompt, - height=self.sample_image_size, - width=self.sample_image_size, - image=latents[:len(prompt)] if latents is not None else None, - generator=generator if latents is not None else None, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - output_type='pil' - ).images - - all_samples += samples - - del samples - - image_grid = make_grid(all_samples, grid_rows, grid_cols) - image_grid.save(file_path, quality=85) - - del all_samples - del image_grid + for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]: + all_samples = [] + file_path = samples_path.joinpath(pool, f"step_{step}.jpg") + file_path.parent.mkdir(parents=True, exist_ok=True) + + data_enum = enumerate(data) + + batches = [ + batch + for j, batch in data_enum + if j * data.batch_size < self.sample_batch_size * self.sample_batches + ] + prompts = [ + prompt + for batch in batches + for prompt in batch["prompts"] + ] + nprompts = [ + prompt + for batch in batches + for prompt in batch["nprompts"] + ] + + for i in range(self.sample_batches): + prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] + nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] + + samples = pipeline( + prompt=prompt, + negative_prompt=nprompt, + height=self.sample_image_size, + width=self.sample_image_size, + generator=gen, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + output_type='pil' + ).images + + all_samples += samples + + del samples + + image_grid = make_grid(all_samples, grid_rows, grid_cols) + image_grid.save(file_path, quality=85) + + del all_samples + del image_grid del generator - del stable_latents -- cgit v1.2.3-70-g09d2