From 7505f7e843dc719622a15f4ee301609813763d77 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Dec 2022 23:50:24 +0100 Subject: Code simplifications, avoid autocast --- train_ti.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 5f37d54..a228795 100644 --- a/train_ti.py +++ b/train_ti.py @@ -361,6 +361,7 @@ def parse_args(): class Checkpointer(CheckpointerBase): def __init__( self, + weight_dtype, datamodule, accelerator, vae, @@ -387,6 +388,7 @@ class Checkpointer(CheckpointerBase): sample_batch_size=sample_batch_size ) + self.weight_dtype = weight_dtype self.accelerator = accelerator self.vae = vae self.unet = unet @@ -417,8 +419,9 @@ class Checkpointer(CheckpointerBase): @torch.no_grad() def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): text_encoder = self.accelerator.unwrap_model(self.text_encoder) + orig_dtype = text_encoder.dtype + text_encoder.to(dtype=self.weight_dtype) - # Save a sample image pipeline = VlpnStableDiffusion( text_encoder=text_encoder, vae=self.vae, @@ -430,6 +433,8 @@ class Checkpointer(CheckpointerBase): super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + text_encoder.to(dtype=orig_dtype) + del text_encoder del pipeline @@ -739,6 +744,7 @@ def main(): max_acc_val = 0.0 checkpointer = Checkpointer( + weight_dtype=weight_dtype, datamodule=datamodule, accelerator=accelerator, vae=vae, -- cgit v1.2.3-54-g00ecf