From 7505f7e843dc719622a15f4ee301609813763d77 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Dec 2022 23:50:24 +0100 Subject: Code simplifications, avoid autocast --- train_dreambooth.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index e239833..2c765ec 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -389,6 +389,7 @@ def parse_args(): class Checkpointer(CheckpointerBase): def __init__( self, + weight_dtype, datamodule, accelerator, vae, @@ -416,6 +417,7 @@ class Checkpointer(CheckpointerBase): sample_batch_size=sample_batch_size ) + self.weight_dtype = weight_dtype self.accelerator = accelerator self.vae = vae self.unet = unet @@ -452,6 +454,12 @@ class Checkpointer(CheckpointerBase): unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) text_encoder = self.accelerator.unwrap_model(self.text_encoder) + orig_unet_dtype = unet.dtype + orig_text_encoder_dtype = text_encoder.dtype + + unet.to(dtype=self.weight_dtype) + text_encoder.to(dtype=self.weight_dtype) + pipeline = VlpnStableDiffusion( text_encoder=text_encoder, vae=self.vae, @@ -463,6 +471,9 @@ class Checkpointer(CheckpointerBase): super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + unet.to(dtype=orig_unet_dtype) + text_encoder.to(dtype=orig_text_encoder_dtype) + del unet del text_encoder del pipeline @@ -798,6 +809,7 @@ def main(): max_acc_val = 0.0 checkpointer = Checkpointer( + weight_dtype=weight_dtype, datamodule=datamodule, accelerator=accelerator, vae=vae, -- cgit v1.2.3-54-g00ecf