 main.py | 49 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 19 deletions(-)
diff --git a/main.py b/main.py
index 9bf65a5..aa5af72 100644
--- a/main.py
+++ b/main.py
@@ -351,16 +351,18 @@ class Checkpointer:
         # Generate and save stable samples
         for i in range(0, self.stable_sample_batches):
             prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
-            samples = pipeline(
-                prompt=prompt,
-                height=self.sample_image_size,
-                latents=stable_latents,
-                width=self.sample_image_size,
-                guidance_scale=guidance_scale,
-                eta=eta,
-                num_inference_steps=num_inference_steps,
-                output_type='pil'
-            )["sample"]
+
+            with self.accelerator.autocast():
+                samples = pipeline(
+                    prompt=prompt,
+                    height=self.sample_image_size,
+                    latents=stable_latents,
+                    width=self.sample_image_size,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                )["sample"]
 
             all_samples += samples
             del samples
@@ -378,15 +380,17 @@ class Checkpointer:
         # Generate and save random samples
         for i in range(0, self.random_sample_batches):
             prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
-            samples = pipeline(
-                prompt=prompt,
-                height=self.sample_image_size,
-                width=self.sample_image_size,
-                guidance_scale=guidance_scale,
-                eta=eta,
-                num_inference_steps=num_inference_steps,
-                output_type='pil'
-            )["sample"]
+
+            with self.accelerator.autocast():
+                samples = pipeline(
+                    prompt=prompt,
+                    height=self.sample_image_size,
+                    width=self.sample_image_size,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                )["sample"]
 
             all_samples += samples
             del samples
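
Note: accelerator.autocast() is the Accelerate context manager that runs the
enclosed ops under the accelerator's configured mixed-precision policy, so the
sampling calls above match the precision used during training. A minimal,
runnable sketch of the pattern (assuming the accelerate package is installed;
the model and input below are stand-ins, not the pipeline from main.py):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()  # mixed_precision set via config/CLI

    # Stand-in for the trained model / diffusion pipeline.
    model = accelerator.prepare(torch.nn.Linear(8, 8))
    x = torch.randn(2, 8, device=accelerator.device)

    with accelerator.autocast():
        # Ops here run in the accelerator's autocast context; this is
        # a no-op when training in full fp32.
        y = model(x)
    print(y.dtype)
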
@@ -682,9 +686,13 @@ def main():
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 local_progress_bar.update(1)
+
                 global_step += 1
 
                 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
+                    progress_bar.clear()
+                    local_progress_bar.clear()
+
                     checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
                     save_resume_file(basepath, args, {
                         "global_step": global_step + global_step_offset,
@@ -740,6 +748,9 @@ def main():
 
         accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
+        progress_bar.clear()
+        local_progress_bar.clear()
+
         if min_val_loss > val_loss:
             accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
             min_val_loss = val_loss
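
Note: the added progress_bar.clear() / local_progress_bar.clear() calls take
the tqdm bars off the terminal before the checkpoint and validation messages
print, so log lines don't get interleaved with a half-drawn bar; the next
update() redraws the bar afterwards. A small, self-contained demo of the same
pattern (assuming tqdm; names here are illustrative, not from main.py):

    import time
    from tqdm import tqdm

    progress_bar = tqdm(total=10, desc="Steps")
    for step in range(1, 11):
        progress_bar.update(1)
        if step % 5 == 0:
            progress_bar.clear()                 # remove the bar from screen
            print(f"checkpoint at step {step}")  # clean log line
            # the next update() redraws the bar below the message
        time.sleep(0.05)
    progress_bar.close()
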