diff options
author | Volpeon <git@volpeon.ink> | 2022-12-29 09:00:19 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-29 09:00:19 +0100 |
commit | 4d3d318a4168ef79847737cef2c0ad8a4dafd3e7 (patch) | |
tree | 967e2c1ee6e2c29b9b6ffaff3e8978f4a43a529d /train_ti.py | |
parent | Updated 1-cycle scheduler (diff) | |
download | textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.gz textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.bz2 textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.zip |
Training improvements
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index d7696e5..b1f6a49 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -903,12 +903,21 @@ def main(): | |||
903 | 903 | ||
904 | text_encoder.eval() | 904 | text_encoder.eval() |
905 | 905 | ||
906 | cur_loss_val = AverageMeter() | ||
907 | cur_acc_val = AverageMeter() | ||
908 | |||
906 | with torch.inference_mode(): | 909 | with torch.inference_mode(): |
907 | for step, batch in enumerate(val_dataloader): | 910 | for step, batch in enumerate(val_dataloader): |
908 | loss, acc, bsz = loop(batch) | 911 | loss, acc, bsz = loop(batch) |
909 | 912 | ||
910 | avg_loss_val.update(loss.detach_(), bsz) | 913 | loss = loss.detach_() |
911 | avg_acc_val.update(acc.detach_(), bsz) | 914 | acc = acc.detach_() |
915 | |||
916 | cur_loss_val.update(loss, bsz) | ||
917 | cur_acc_val.update(acc, bsz) | ||
918 | |||
919 | avg_loss_val.update(loss, bsz) | ||
920 | avg_acc_val.update(acc, bsz) | ||
912 | 921 | ||
913 | local_progress_bar.update(1) | 922 | local_progress_bar.update(1) |
914 | global_progress_bar.update(1) | 923 | global_progress_bar.update(1) |
@@ -921,10 +930,10 @@ def main(): | |||
921 | } | 930 | } |
922 | local_progress_bar.set_postfix(**logs) | 931 | local_progress_bar.set_postfix(**logs) |
923 | 932 | ||
924 | accelerator.log({ | 933 | logs["val/cur_loss"] = cur_loss_val.avg.item() |
925 | "val/loss": avg_loss_val.avg.item(), | 934 | logs["val/cur_acc"] = cur_acc_val.avg.item() |
926 | "val/acc": avg_acc_val.avg.item(), | 935 | |
927 | }, step=global_step) | 936 | accelerator.log(logs, step=global_step) |
928 | 937 | ||
929 | local_progress_bar.clear() | 938 | local_progress_bar.clear() |
930 | global_progress_bar.clear() | 939 | global_progress_bar.clear() |