summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-29 09:00:19 +0100
committerVolpeon <git@volpeon.ink>2022-12-29 09:00:19 +0100
commit4d3d318a4168ef79847737cef2c0ad8a4dafd3e7 (patch)
tree967e2c1ee6e2c29b9b6ffaff3e8978f4a43a529d /train_ti.py
parentUpdated 1-cycle scheduler (diff)
downloadtextual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.gz
textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.bz2
textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.zip
Training improvements
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py
index d7696e5..b1f6a49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -903,12 +903,21 @@ def main():
903 903
904 text_encoder.eval() 904 text_encoder.eval()
905 905
906 cur_loss_val = AverageMeter()
907 cur_acc_val = AverageMeter()
908
906 with torch.inference_mode(): 909 with torch.inference_mode():
907 for step, batch in enumerate(val_dataloader): 910 for step, batch in enumerate(val_dataloader):
908 loss, acc, bsz = loop(batch) 911 loss, acc, bsz = loop(batch)
909 912
910 avg_loss_val.update(loss.detach_(), bsz) 913 loss = loss.detach_()
911 avg_acc_val.update(acc.detach_(), bsz) 914 acc = acc.detach_()
915
916 cur_loss_val.update(loss, bsz)
917 cur_acc_val.update(acc, bsz)
918
919 avg_loss_val.update(loss, bsz)
920 avg_acc_val.update(acc, bsz)
912 921
913 local_progress_bar.update(1) 922 local_progress_bar.update(1)
914 global_progress_bar.update(1) 923 global_progress_bar.update(1)
@@ -921,10 +930,10 @@ def main():
921 } 930 }
922 local_progress_bar.set_postfix(**logs) 931 local_progress_bar.set_postfix(**logs)
923 932
924 accelerator.log({ 933 logs["val/cur_loss"] = cur_loss_val.avg.item()
925 "val/loss": avg_loss_val.avg.item(), 934 logs["val/cur_acc"] = cur_acc_val.avg.item()
926 "val/acc": avg_acc_val.avg.item(), 935
927 }, step=global_step) 936 accelerator.log(logs, step=global_step)
928 937
929 local_progress_bar.clear() 938 local_progress_bar.clear()
930 global_progress_bar.clear() 939 global_progress_bar.clear()