From 7505f7e843dc719622a15f4ee301609813763d77 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 25 Dec 2022 23:50:24 +0100
Subject: Code simplifications, avoid autocast

---
 infer.py            |  48 ++++++++++++------------
 train_dreambooth.py |  12 ++++++
 train_ti.py         |   8 +++-
 training/util.py    | 106 ++++++++++++++++++++++++----------------------------
 4 files changed, 92 insertions(+), 82 deletions(-)

diff --git a/infer.py b/infer.py
index 420cb83..f566114 100644
--- a/infer.py
+++ b/infer.py
@@ -209,6 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype):
     return pipeline
 
 
+@torch.inference_mode()
 def generate(output_dir, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
@@ -245,30 +246,29 @@ def generate(output_dir, pipeline, args):
     elif args.scheduler == "kdpm2_a":
         pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
-    with torch.autocast("cuda"), torch.inference_mode():
-        for i in range(args.batch_num):
-            pipeline.set_progress_bar_config(
-                desc=f"Batch {i + 1} of {args.batch_num}",
-                dynamic_ncols=True
-            )
-
-            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
-            images = pipeline(
-                prompt=args.prompt,
-                negative_prompt=args.negative_prompt,
-                height=args.height,
-                width=args.width,
-                num_images_per_prompt=args.batch_size,
-                num_inference_steps=args.steps,
-                guidance_scale=args.guidance_scale,
-                generator=generator,
-                image=init_image,
-                strength=args.image_noise,
-            ).images
-
-            for j, image in enumerate(images):
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
+    for i in range(args.batch_num):
+        pipeline.set_progress_bar_config(
+            desc=f"Batch {i + 1} of {args.batch_num}",
+            dynamic_ncols=True
+        )
+
+        generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
+        images = pipeline(
+            prompt=args.prompt,
+            negative_prompt=args.negative_prompt,
+            height=args.height,
+            width=args.width,
+            num_images_per_prompt=args.batch_size,
+            num_inference_steps=args.steps,
+            guidance_scale=args.guidance_scale,
+            generator=generator,
+            image=init_image,
+            strength=args.image_noise,
+        ).images
+
+        for j, image in enumerate(images):
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e239833..2c765ec 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -389,6 +389,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
+        weight_dtype,
         datamodule,
         accelerator,
         vae,
@@ -416,6 +417,7 @@ class Checkpointer(CheckpointerBase):
             sample_batch_size=sample_batch_size
         )
 
+        self.weight_dtype = weight_dtype
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -452,6 +454,12 @@ class Checkpointer(CheckpointerBase):
         unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
+        orig_unet_dtype = unet.dtype
+        orig_text_encoder_dtype = text_encoder.dtype
+
+        unet.to(dtype=self.weight_dtype)
+        text_encoder.to(dtype=self.weight_dtype)
+
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=self.vae,
@@ -463,6 +471,9 @@ class Checkpointer(CheckpointerBase):
 
         super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
+        unet.to(dtype=orig_unet_dtype)
+        text_encoder.to(dtype=orig_text_encoder_dtype)
+
         del unet
         del text_encoder
         del pipeline
@@ -798,6 +809,7 @@ def main():
     max_acc_val = 0.0
 
     checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
         datamodule=datamodule,
         accelerator=accelerator,
         vae=vae,
diff --git a/train_ti.py b/train_ti.py
index 5f37d54..a228795 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -361,6 +361,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
+        weight_dtype,
         datamodule,
         accelerator,
         vae,
@@ -387,6 +388,7 @@ class Checkpointer(CheckpointerBase):
             sample_batch_size=sample_batch_size
         )
 
+        self.weight_dtype = weight_dtype
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -417,8 +419,9 @@ class Checkpointer(CheckpointerBase):
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+        orig_dtype = text_encoder.dtype
+        text_encoder.to(dtype=self.weight_dtype)
 
-        # Save a sample image
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=self.vae,
@@ -430,6 +433,8 @@ class Checkpointer(CheckpointerBase):
 
         super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
+        text_encoder.to(dtype=orig_dtype)
+
         del text_encoder
         del pipeline
 
@@ -739,6 +744,7 @@ def main():
     max_acc_val = 0.0
 
     checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
         datamodule=datamodule,
         accelerator=accelerator,
         vae=vae,
diff --git a/training/util.py b/training/util.py
index 5c056a6..a0c15cd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -60,7 +60,7 @@ class CheckpointerBase:
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
         samples_path.mkdir(parents=True, exist_ok=True)
@@ -68,65 +68,57 @@ class CheckpointerBase:
         val_data = self.datamodule.val_dataloader()
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
 
         grid_cols = min(self.sample_batch_size, 4)
         grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
 
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, grid_rows, grid_cols)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
+        for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]:
+            all_samples = []
+            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
+
+            data_enum = enumerate(data)
+
+            batches = [
+                batch
+                for j, batch in data_enum
+                if j * data.batch_size < self.sample_batch_size * self.sample_batches
+            ]
+            prompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["prompts"]
+            ]
+            nprompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompts"]
+            ]
+
+            for i in range(self.sample_batches):
+                prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+
+                samples = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=self.sample_image_size,
+                    width=self.sample_image_size,
+                    generator=gen,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                ).images
+
+                all_samples += samples
+
+                del samples
+
+            image_grid = make_grid(all_samples, grid_rows, grid_cols)
+            image_grid.save(file_path, quality=85)
+
+            del all_samples
+            del image_grid
 
         del generator
-        del stable_latents
--
cgit v1.2.3-54-g00ecf
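
The pattern applied throughout this commit: drop the autocast/no_grad wrappers in favor of a torch.inference_mode() decorator, and for in-training sampling cast the model to the configured weight dtype before building the pipeline, then restore the original dtype afterwards. A minimal standalone sketch of that pattern follows; the generate_samples helper, its arguments, and the float16 default are illustrative assumptions, not code from this repository.

import torch

@torch.inference_mode()  # no gradients recorded; replaces torch.autocast("cuda") + no_grad
def generate_samples(pipeline, text_encoder, weight_dtype=torch.float16):
    # Hypothetical helper showing the dtype round-trip used by the Checkpointer classes.
    orig_dtype = text_encoder.dtype        # remember the dtype the model trains in
    text_encoder.to(dtype=weight_dtype)    # cast once for sampling instead of per-op autocast
    images = pipeline(
        prompt=["a sample prompt"],
        num_inference_steps=20,
    ).images
    text_encoder.to(dtype=orig_dtype)      # restore so training resumes in the original dtype
    return images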