 main.py | 49 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/main.py b/main.py
--- a/main.py
+++ b/main.py
@@ -351,16 +351,18 @@ class Checkpointer:
         # Generate and save stable samples
         for i in range(0, self.stable_sample_batches):
             prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
-            samples = pipeline(
-                prompt=prompt,
-                height=self.sample_image_size,
-                latents=stable_latents,
-                width=self.sample_image_size,
-                guidance_scale=guidance_scale,
-                eta=eta,
-                num_inference_steps=num_inference_steps,
-                output_type='pil'
-            )["sample"]
+
+            with self.accelerator.autocast():
+                samples = pipeline(
+                    prompt=prompt,
+                    height=self.sample_image_size,
+                    latents=stable_latents,
+                    width=self.sample_image_size,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                )["sample"]
 
             all_samples += samples
             del samples
@@ -378,15 +380,17 @@ class Checkpointer:
         # Generate and save random samples
         for i in range(0, self.random_sample_batches):
             prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
-            samples = pipeline(
-                prompt=prompt,
-                height=self.sample_image_size,
-                width=self.sample_image_size,
-                guidance_scale=guidance_scale,
-                eta=eta,
-                num_inference_steps=num_inference_steps,
-                output_type='pil'
-            )["sample"]
+
+            with self.accelerator.autocast():
+                samples = pipeline(
+                    prompt=prompt,
+                    height=self.sample_image_size,
+                    width=self.sample_image_size,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                )["sample"]
 
             all_samples += samples
             del samples
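
Note on the two hunks above: sample generation is now wrapped in accelerate's Accelerator.autocast() context manager, so the pipeline runs under the accelerator's mixed-precision policy instead of plain full precision. A minimal standalone sketch of that pattern follows; the tiny linear model, the fp16 setting, and the tensor shapes are illustrative assumptions, not code from this repository.

# Sketch of the Accelerator.autocast() pattern used above. Only the
# accelerator.autocast() call mirrors the actual change in this diff;
# everything else is a placeholder.
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # fp16 assumes a CUDA device
model = accelerator.prepare(torch.nn.Linear(8, 8))
x = torch.randn(4, 8, device=accelerator.device)

with torch.no_grad(), accelerator.autocast():
    # Ops inside this block run under the accelerator's mixed-precision
    # policy, just as the checkpointer's pipeline() calls do after this change.
    y = model(x)
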
@@ -682,9 +686,13 @@ def main():
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 local_progress_bar.update(1)
+
                 global_step += 1
 
                 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
+                    progress_bar.clear()
+                    local_progress_bar.clear()
+
                     checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
                     save_resume_file(basepath, args, {
                         "global_step": global_step + global_step_offset,
@@ -740,6 +748,9 @@ def main():
 
         accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
+        progress_bar.clear()
+        local_progress_bar.clear()
+
         if min_val_loss > val_loss:
             accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
             min_val_loss = val_loss
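
Note on the last two hunks: both tqdm bars are cleared before the checkpoint and validation messages are printed, so log lines do not interleave with a half-drawn progress bar. A small illustrative example of that pattern follows; the loop, sleep, and message are placeholders, not code from this file.

# Illustrative tqdm pattern matching the clear() calls added above: erase the
# bar before printing a status line, then redraw it.
import time
from tqdm.auto import tqdm

progress_bar = tqdm(total=100, desc="steps")
for step in range(1, 101):
    time.sleep(0.01)
    progress_bar.update(1)
    if step % 25 == 0:
        progress_bar.clear()                  # remove the bar from the terminal
        print(f"checkpoint at step {step}")   # message prints on a clean line
        progress_bar.refresh()                # redraw the bar immediately
progress_bar.close()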