From 2ab3573380af41a5a70db7efae74728e560c6f0e Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 26 Sep 2022 17:55:45 +0200
Subject: Autocast on sample generation, progress bar cleanup

---
 main.py | 49 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/main.py b/main.py
index 9bf65a5..aa5af72 100644
--- a/main.py
+++ b/main.py
@@ -351,16 +351,18 @@ class Checkpointer:
         # Generate and save stable samples
         for i in range(0, self.stable_sample_batches):
             prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
-            samples = pipeline(
-                prompt=prompt,
-                height=self.sample_image_size,
-                latents=stable_latents,
-                width=self.sample_image_size,
-                guidance_scale=guidance_scale,
-                eta=eta,
-                num_inference_steps=num_inference_steps,
-                output_type='pil'
-            )["sample"]
+
+            with self.accelerator.autocast():
+                samples = pipeline(
+                    prompt=prompt,
+                    height=self.sample_image_size,
+                    latents=stable_latents,
+                    width=self.sample_image_size,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                )["sample"]
 
             all_samples += samples
             del samples
@@ -378,15 +380,17 @@ class Checkpointer:
         # Generate and save random samples
         for i in range(0, self.random_sample_batches):
             prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
-            samples = pipeline(
-                prompt=prompt,
-                height=self.sample_image_size,
-                width=self.sample_image_size,
-                guidance_scale=guidance_scale,
-                eta=eta,
-                num_inference_steps=num_inference_steps,
-                output_type='pil'
-            )["sample"]
+
+            with self.accelerator.autocast():
+                samples = pipeline(
+                    prompt=prompt,
+                    height=self.sample_image_size,
+                    width=self.sample_image_size,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                )["sample"]
 
             all_samples += samples
             del samples
@@ -682,9 +686,13 @@ def main():
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
                     local_progress_bar.update(1)
+
                     global_step += 1
 
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
+                        progress_bar.clear()
+                        local_progress_bar.clear()
+
                         checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
                         save_resume_file(basepath, args, {
                             "global_step": global_step + global_step_offset,
@@ -740,6 +748,9 @@ def main():
 
             accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
+            progress_bar.clear()
+            local_progress_bar.clear()
+
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 min_val_loss = val_loss
--
cgit v1.2.3-54-g00ecf
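
For reference, a minimal standalone sketch of the autocast pattern this patch
adopts. The model id, prompt, and Accelerator settings below are illustrative
assumptions, not taken from the repository; the pipeline call mirrors the
diffusers API of the era, which returned generated images under the "sample"
key (newer releases expose them as .images instead).

    from accelerate import Accelerator
    from diffusers import StableDiffusionPipeline

    # Assumed precision setting; the repository configures its own Accelerator.
    accelerator = Accelerator(mixed_precision="fp16")

    # Placeholder model id; main.py builds its pipeline from the training model.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5"
    ).to(accelerator.device)

    # accelerator.autocast() runs the enclosed forward passes under the
    # configured mixed precision, which is what the patch wraps around
    # the Checkpointer's sample generation.
    with accelerator.autocast():
        samples = pipeline(
            prompt=["a sample prompt"],
            height=512,
            width=512,
            guidance_scale=7.5,
            eta=0.0,
            num_inference_steps=50,
            output_type="pil",
        )["sample"]

The progress-bar half of the change follows the usual tqdm idiom: calling
clear() on a bar before printing other console output (checkpoint and loss
messages) prevents that output from interleaving with a stale bar line, and
tqdm redraws the bar on its next update().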