From b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 27 Nov 2022 16:57:29 +0100
Subject: Update

---
 dreambooth.py | 54 ++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 38 insertions(+), 16 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 79b3d2c..2b8a35e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -859,7 +859,14 @@ def main():
 
     # Only show the progress bar once on each machine.
     global_step = 0
-    min_val_loss = np.inf
+
+    total_loss = 0.0
+    total_acc = 0.0
+
+    total_loss_val = 0.0
+    total_acc_val = 0.0
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
@@ -905,7 +912,6 @@ def main():
 
             unet.train()
             text_encoder.train()
-            train_loss = 0.0
 
             sample_checkpoint = False
 
@@ -978,8 +984,11 @@ def main():
                         ema_unet.step(unet)
                     optimizer.zero_grad(set_to_none=True)
 
-                loss = loss.detach().item()
-                train_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss += loss.item()
+                total_acc += acc.item()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
@@ -996,7 +1005,10 @@ def main():
                         sample_checkpoint = True
 
                 logs = {
-                    "train/loss": loss,
+                    "train/loss": total_loss / global_step,
+                    "train/acc": total_acc / global_step,
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
                     "lr/unet": lr_scheduler.get_last_lr()[0],
                     "lr/text": lr_scheduler.get_last_lr()[1]
                 }
@@ -1010,13 +1022,10 @@ def main():
                 if global_step >= args.max_train_steps:
                     break
 
-            train_loss /= len(train_dataloader)
-
             accelerator.wait_for_everyone()
 
             unet.eval()
             text_encoder.eval()
-            val_loss = 0.0
 
             with torch.autocast("cuda"), torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
@@ -1039,28 +1048,41 @@ def main():
 
                     loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
-                    loss = loss.detach().item()
-                    val_loss += loss
+                    acc = (noise_pred == latents).float()
+                    acc = acc.mean()
+
+                    total_loss_val += loss.item()
+                    total_acc_val += acc.item()
 
                     if accelerator.sync_gradients:
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
 
-                    logs = {"val/loss": loss}
+                    logs = {
+                        "val/loss": total_loss_val / global_step,
+                        "val/acc": total_acc_val / global_step,
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
                     local_progress_bar.set_postfix(**logs)
 
-            val_loss /= len(val_dataloader)
+            val_step = (epoch + 1) * len(val_dataloader)
+            avg_acc_val = total_acc_val / val_step
+            avg_loss_val = total_loss_val / val_step
 
-            accelerator.log({"val/loss": val_loss}, step=global_step)
+            accelerator.log({
+                "val/loss": avg_loss_val,
+                "val/acc": avg_acc_val,
+            }, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if min_val_loss > val_loss:
+            if avg_acc_val > max_acc_val:
                 accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    f"Global step {global_step}: Validation loss reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
                 checkpointer.save_embedding(global_step, "milestone")
-                min_val_loss = val_loss
+                max_acc_val = avg_acc_val
 
             if sample_checkpoint and accelerator.is_main_process:
                 checkpointer.save_samples(
--
cgit v1.2.3-54-g00ecf
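
Note for readers browsing the log: the patch switches milestone checkpointing from lowest running validation loss to highest running validation accuracy, and logs running averages alongside the per-step values. The sketch below restates that bookkeeping pattern in isolation; it is not code from dreambooth.py, and run_batch and save_milestone are hypothetical placeholders standing in for the model forward pass and the Checkpointer call.

# Minimal sketch of the metric bookkeeping introduced by this patch.
# Assumptions: `run_batch` returns (loss, acc) as floats and `save_milestone`
# persists a checkpoint; both are hypothetical, not part of dreambooth.py.

def training_loop(num_epochs, train_batches, val_batches, run_batch, save_milestone):
    global_step = 0
    total_loss = total_acc = 0.0          # running sums over all training steps
    total_loss_val = total_acc_val = 0.0  # running sums over all validation steps
    max_acc_val = 0.0                     # best running validation accuracy so far

    for epoch in range(num_epochs):
        for batch in train_batches:
            loss, acc = run_batch(batch, train=True)
            global_step += 1
            total_loss += loss
            total_acc += acc
            # Logged averages cover every step since training started, not just
            # the current epoch; a real loop would hand these to the tracker.
            logs = {
                "train/loss": total_loss / global_step,
                "train/acc": total_acc / global_step,
                "train/cur_loss": loss,
                "train/cur_acc": acc,
            }

        for batch in val_batches:
            loss, acc = run_batch(batch, train=False)
            total_loss_val += loss
            total_acc_val += acc

        # Normalise by the number of validation steps seen across all epochs,
        # mirroring `val_step = (epoch + 1) * len(val_dataloader)` above.
        val_step = (epoch + 1) * len(val_batches)
        avg_loss_val = total_loss_val / val_step
        avg_acc_val = total_acc_val / val_step

        # Save a milestone whenever the running validation accuracy improves;
        # this replaces the previous lowest-validation-loss criterion.
        if avg_acc_val > max_acc_val:
            save_milestone(global_step, avg_loss_val, avg_acc_val)
            max_acc_val = avg_acc_val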