Diffstat (limited to 'dreambooth.py')
 -rw-r--r--  dreambooth.py  54
 1 file changed, 38 insertions(+), 16 deletions(-)
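In short, this commit replaces the per-epoch loss accumulators (and the min-validation-loss milestone trigger) with running totals averaged over all steps so far, adds an accuracy metric, and triggers milestones on a new maximum validation accuracy instead. A minimal sketch of the running-average logging pattern the hunks below adopt, with toy stand-ins for the real training loop (only the accumulation pattern and the log keys mirror the diff):

total_loss = 0.0
global_step = 0

for cur_loss in [0.9, 0.7, 0.5]:                 # stand-in per-step losses
    global_step += 1
    total_loss += cur_loss
    logs = {
        "train/loss": total_loss / global_step,  # running average so far
        "train/cur_loss": cur_loss,              # current step value
    }
    print(logs)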
diff --git a/dreambooth.py b/dreambooth.py
index 79b3d2c..2b8a35e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -859,7 +859,14 @@ def main():
     # Only show the progress bar once on each machine.
 
     global_step = 0
-    min_val_loss = np.inf
+
+    total_loss = 0.0
+    total_acc = 0.0
+
+    total_loss_val = 0.0
+    total_acc_val = 0.0
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
@@ -905,7 +912,6 @@ def main():
 
         unet.train()
         text_encoder.train()
-        train_loss = 0.0
 
         sample_checkpoint = False
 
@@ -978,8 +984,11 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-                loss = loss.detach().item()
-                train_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss += loss.item()
+                total_acc += acc.item()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
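A caveat on the hunk above: (noise_pred == latents).float() takes exact float equality between a prediction and a continuous tensor, so its mean will be 0.0 in practice, and the prediction is compared against latents even though the loss targets noise. A tolerance-based comparison is the usual alternative for a regression-style accuracy; a hedged sketch, where the atol value is an assumption and not from this repository:

import torch

# Exact equality vs. tolerance-based "accuracy" on float tensors.
noise_pred = torch.randn(4, 64)
target = noise_pred + 0.01 * torch.randn(4, 64)   # near-identical target

exact_acc = (noise_pred == target).float().mean()                      # ~0.0
tol_acc = torch.isclose(noise_pred, target, atol=0.05).float().mean()  # ~1.0
print(exact_acc.item(), tol_acc.item())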
@@ -996,7 +1005,10 @@ def main():
                         sample_checkpoint = True
 
                     logs = {
-                        "train/loss": loss,
+                        "train/loss": total_loss / global_step,
+                        "train/acc": total_acc / global_step,
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
                         "lr/unet": lr_scheduler.get_last_lr()[0],
                         "lr/text": lr_scheduler.get_last_lr()[1]
                     }
@@ -1010,13 +1022,10 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        train_loss /= len(train_dataloader)
-
         accelerator.wait_for_everyone()
 
         unet.eval()
         text_encoder.eval()
-        val_loss = 0.0
 
         with torch.autocast("cuda"), torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
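The final hunk below swaps the milestone criterion from a new minimum validation loss to a new maximum validation accuracy. The gate it implements, as a self-contained sketch where save_milestone is a hypothetical stand-in for checkpointer.save_embedding:

max_acc_val = 0.0

def save_milestone(step):   # hypothetical stand-in for checkpointer.save_embedding
    print(f"milestone saved at step {step}")

for global_step, avg_acc_val in enumerate([0.10, 0.32, 0.28, 0.51], start=1):
    if avg_acc_val > max_acc_val:
        print(f"Global step {global_step}: validation accuracy reached "
              f"new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
        save_milestone(global_step)
        max_acc_val = avg_acc_val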
@@ -1039,28 +1048,41 @@ def main():
 
                 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
-                loss = loss.detach().item()
-                val_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss_val += loss.item()
+                total_acc_val += acc.item()
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
+                logs = {
+                    "val/loss": total_loss_val / global_step,
+                    "val/acc": total_acc_val / global_step,
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
                 local_progress_bar.set_postfix(**logs)
 
-            val_loss /= len(val_dataloader)
+            val_step = (epoch + 1) * len(val_dataloader)
+            avg_acc_val = total_acc_val / val_step
+            avg_loss_val = total_loss_val / val_step
 
-            accelerator.log({"val/loss": val_loss}, step=global_step)
+            accelerator.log({
+                "val/loss": avg_loss_val,
+                "val/acc": avg_acc_val,
+            }, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if min_val_loss > val_loss:
+            if avg_acc_val > max_acc_val:
                 accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
                 checkpointer.save_embedding(global_step, "milestone")
-                min_val_loss = val_loss
+                max_acc_val = avg_acc_val
 
             if sample_checkpoint and accelerator.is_main_process:
                 checkpointer.save_samples(