| field     | value                                                    |                           |
|-----------|----------------------------------------------------------|---------------------------|
| author    | Volpeon <git@volpeon.ink>                                | 2022-11-27 16:57:29 +0100 |
| committer | Volpeon <git@volpeon.ink>                                | 2022-11-27 16:57:29 +0100 |
| commit    | b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d                 |                           |
| tree      | 2ad3740868696fc071d8850171e6e53ccc3a7bd2 /dreambooth.py  |                           |
| parent    | Update                                                   |                           |
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- dreambooth.py | 54

1 file changed, 38 insertions(+), 16 deletions(-)
```diff
diff --git a/dreambooth.py b/dreambooth.py
index 79b3d2c..2b8a35e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -859,7 +859,14 @@ def main():
     # Only show the progress bar once on each machine.
 
     global_step = 0
-    min_val_loss = np.inf
+
+    total_loss = 0.0
+    total_acc = 0.0
+
+    total_loss_val = 0.0
+    total_acc_val = 0.0
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
```
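This first hunk replaces the single `min_val_loss` tracker with cumulative loss/accuracy counters that persist across epochs; running averages are later derived by dividing the totals by a step count. A minimal sketch of that pattern, with illustrative values not taken from the repository:

```python
# Running-total pattern: keep cumulative sums and derive the mean
# from a monotonically increasing step counter.
total_loss = 0.0
global_step = 0

for batch_loss in (0.9, 0.7, 0.5):  # stand-in values for per-batch losses
    total_loss += batch_loss
    global_step += 1
    print(f"step {global_step}: avg loss {total_loss / global_step:.3f}")
```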
```diff
@@ -905,7 +912,6 @@ def main():
 
         unet.train()
         text_encoder.train()
-        train_loss = 0.0
 
         sample_checkpoint = False
 
```
```diff
@@ -978,8 +984,11 @@
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-                loss = loss.detach().item()
-                train_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss += loss.item()
+                total_acc += acc.item()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
```
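One caveat worth noting: `(noise_pred == latents).float()` is an element-wise exact-equality test between floating-point tensors, so as an "accuracy" for a continuous regression target it stays at zero unless elements happen to be bit-identical. A small PyTorch snippet illustrating the behaviour (tensor values invented for the example):

```python
import torch

noise_pred = torch.tensor([0.5001, 1.0, -0.25])
latents = torch.tensor([0.5000, 1.0, -0.25])

# Only bit-identical elements compare equal; near-misses count as 0.
acc = (noise_pred == latents).float().mean()
print(acc)  # tensor(0.6667)
```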
```diff
@@ -996,7 +1005,10 @@
                         sample_checkpoint = True
 
                     logs = {
-                        "train/loss": loss,
+                        "train/loss": total_loss / global_step,
+                        "train/acc": total_acc / global_step,
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
                         "lr/unet": lr_scheduler.get_last_lr()[0],
                         "lr/text": lr_scheduler.get_last_lr()[1]
                     }
```
```diff
@@ -1010,13 +1022,10 @@
             if global_step >= args.max_train_steps:
                 break
 
-        train_loss /= len(train_dataloader)
-
         accelerator.wait_for_everyone()
 
         unet.eval()
         text_encoder.eval()
-        val_loss = 0.0
 
         with torch.autocast("cuda"), torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
```
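The validation pass runs under `torch.autocast("cuda")` combined with `torch.inference_mode()`, which executes the forward pass in mixed precision without building an autograd graph. A condensed sketch of that evaluation pattern (the toy model and batch are placeholders, and a CUDA device is assumed):

```python
import torch

model = torch.nn.Linear(4, 4).cuda()
batch = torch.randn(2, 4).cuda()

model.eval()
with torch.autocast("cuda"), torch.inference_mode():
    out = model(batch)    # mixed-precision forward, no autograd graph

print(out.requires_grad)  # False
```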
```diff
@@ -1039,28 +1048,41 @@
 
                 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
-                loss = loss.detach().item()
-                val_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss_val += loss.item()
+                total_acc_val += acc.item()
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
+                logs = {
+                    "val/loss": total_loss_val / global_step,
+                    "val/acc": total_acc_val / global_step,
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
                 local_progress_bar.set_postfix(**logs)
 
-        val_loss /= len(val_dataloader)
+        val_step = (epoch + 1) * len(val_dataloader)
+        avg_acc_val = total_acc_val / val_step
+        avg_loss_val = total_loss_val / val_step
 
-        accelerator.log({"val/loss": val_loss}, step=global_step)
+        accelerator.log({
+            "val/loss": avg_loss_val,
+            "val/acc": avg_acc_val,
+        }, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
 
-        if min_val_loss > val_loss:
+        if avg_acc_val > max_acc_val:
             accelerator.print(
-                f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                f"Global step {global_step}: Validation loss reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
             checkpointer.save_embedding(global_step, "milestone")
-            min_val_loss = val_loss
+            max_acc_val = avg_acc_val
 
         if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
```
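Two details of the new bookkeeping stand out. The milestone message still reads "Validation loss reached new maximum" even though the comparison is now on accuracy, and because `total_loss_val` and `total_acc_val` are never reset, dividing by `val_step = (epoch + 1) * len(val_dataloader)` averages over every validation batch seen so far rather than over the current epoch alone. A toy illustration of that arithmetic (loss values invented):

```python
# Two epochs, two validation batches each; per-batch losses are made up.
epoch_losses = [[0.8, 0.6], [0.4, 0.2]]

total_loss_val = 0.0
for epoch, batch_losses in enumerate(epoch_losses):
    for loss in batch_losses:
        total_loss_val += loss
    val_step = (epoch + 1) * len(batch_losses)
    print(f"epoch {epoch}: avg {total_loss_val / val_step:.3f}")
# epoch 0: avg 0.700  -> mean of the current epoch only
# epoch 1: avg 0.500  -> mean over all four batches seen so far
```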
