author     Volpeon <git@volpeon.ink>    2022-11-27 16:57:29 +0100
committer  Volpeon <git@volpeon.ink>    2022-11-27 16:57:29 +0100
commit     b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d (patch)
tree       2ad3740868696fc071d8850171e6e53ccc3a7bd2 /dreambooth.py
parent     Update (diff)
download   textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.tar.gz
           textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.tar.bz2
           textual-inversion-diff-b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d.zip

Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r--   dreambooth.py | 54
1 file changed, 38 insertions, 16 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 79b3d2c..2b8a35e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -859,7 +859,14 @@ def main():
     # Only show the progress bar once on each machine.
 
     global_step = 0
-    min_val_loss = np.inf
+
+    total_loss = 0.0
+    total_acc = 0.0
+
+    total_loss_val = 0.0
+    total_acc_val = 0.0
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
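The counters introduced in this hunk replace the single min_val_loss tracker: per-step loss and accuracy are summed into totals that persist for the whole run and are later divided by the number of steps seen. A minimal self-contained sketch of that running-average bookkeeping (the per-step values below are made up; only the variable names come from the diff):

    # Made-up (loss, acc) pairs standing in for the per-step scalars.
    per_step_metrics = [(0.12, 0.0), (0.10, 0.0), (0.09, 0.0)]

    total_loss = 0.0
    total_acc = 0.0
    global_step = 0

    for loss, acc in per_step_metrics:
        global_step += 1
        total_loss += loss
        total_acc += acc

    # Running averages over every step seen so far, as logged in later hunks.
    avg_loss = total_loss / global_step
    avg_acc = total_acc / global_step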
@@ -905,7 +912,6 @@
 
         unet.train()
         text_encoder.train()
-        train_loss = 0.0
 
         sample_checkpoint = False
 
@@ -978,8 +984,11 @@
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-                loss = loss.detach().item()
-                train_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss += loss.item()
+                total_acc += acc.item()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
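The accuracy term added here is an elementwise equality check between the two tensors, cast to float and averaged, i.e. the fraction of entries that match exactly. A small sketch of just that computation, with illustrative shapes and tensors that are not taken from the script:

    import torch

    # Placeholder tensors; in the script these are the UNet prediction and the latents.
    noise_pred = torch.zeros(2, 4, 64, 64)
    latents = torch.zeros(2, 4, 64, 64)

    acc = (noise_pred == latents).float()  # 1.0 where entries are exactly equal
    acc = acc.mean()                       # fraction of exactly equal entries
    print(acc.item())                      # 1.0 here, since both tensors are all zeros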
@@ -996,7 +1005,10 @@
                         sample_checkpoint = True
 
                 logs = {
-                    "train/loss": loss,
+                    "train/loss": total_loss / global_step,
+                    "train/acc": total_acc / global_step,
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
                     "lr/unet": lr_scheduler.get_last_lr()[0],
                     "lr/text": lr_scheduler.get_last_lr()[1]
                 }
@@ -1010,13 +1022,10 @@
             if global_step >= args.max_train_steps:
                 break
 
-        train_loss /= len(train_dataloader)
-
         accelerator.wait_for_everyone()
 
         unet.eval()
         text_encoder.eval()
-        val_loss = 0.0
 
         with torch.autocast("cuda"), torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
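The validation pass above runs the models in eval mode under autocast and inference_mode, so no autograd graphs are built. A minimal sketch of that scaffolding, using a placeholder model and dataloader rather than anything from the script:

    import torch

    model = torch.nn.Linear(4, 4)                         # placeholder model
    val_dataloader = [torch.randn(2, 4) for _ in range(3)]  # placeholder batches

    model.eval()
    # Autocast is only enabled when CUDA is present in this sketch.
    with torch.autocast("cuda", enabled=torch.cuda.is_available()), torch.inference_mode():
        for batch in val_dataloader:
            out = model(batch)  # no gradients are tracked inside inference_mode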
@@ -1039,28 +1048,41 @@
 
                 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
-                loss = loss.detach().item()
-                val_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss_val += loss.item()
+                total_acc_val += acc.item()
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
+                logs = {
+                    "val/loss": total_loss_val / global_step,
+                    "val/acc": total_acc_val / global_step,
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
                 local_progress_bar.set_postfix(**logs)
 
-        val_loss /= len(val_dataloader)
+        val_step = (epoch + 1) * len(val_dataloader)
+        avg_acc_val = total_acc_val / val_step
+        avg_loss_val = total_loss_val / val_step
 
-        accelerator.log({"val/loss": val_loss}, step=global_step)
+        accelerator.log({
+            "val/loss": avg_loss_val,
+            "val/acc": avg_acc_val,
+        }, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
 
-        if min_val_loss > val_loss:
+        if avg_acc_val > max_acc_val:
             accelerator.print(
-                f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                f"Global step {global_step}: Validation loss reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
             checkpointer.save_embedding(global_step, "milestone")
-            min_val_loss = val_loss
+            max_acc_val = avg_acc_val
 
         if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
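The final hunk switches the milestone trigger from a new minimum of the validation loss to a new maximum of the averaged validation accuracy. A tiny sketch of that checkpoint-on-best-metric pattern, where save_milestone is a hypothetical stand-in for checkpointer.save_embedding and the step/accuracy pairs are invented:

    max_acc_val = 0.0

    def save_milestone(step):  # hypothetical stand-in for the real checkpoint hook
        print(f"milestone checkpoint at step {step}")

    for global_step, avg_acc_val in [(100, 0.10), (200, 0.08), (300, 0.20)]:
        if avg_acc_val > max_acc_val:       # only save when the metric improves
            save_milestone(global_step)
            max_acc_val = avg_acc_val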