diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-29 09:00:19 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-29 09:00:19 +0100 |
| commit | 4d3d318a4168ef79847737cef2c0ad8a4dafd3e7 (patch) | |
| tree | 967e2c1ee6e2c29b9b6ffaff3e8978f4a43a529d /train_ti.py | |
| parent | Updated 1-cycle scheduler (diff) | |
| download | textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.gz textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.tar.bz2 textual-inversion-diff-4d3d318a4168ef79847737cef2c0ad8a4dafd3e7.zip | |
Training improvements
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index d7696e5..b1f6a49 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -903,12 +903,21 @@ def main(): | |||
| 903 | 903 | ||
| 904 | text_encoder.eval() | 904 | text_encoder.eval() |
| 905 | 905 | ||
| 906 | cur_loss_val = AverageMeter() | ||
| 907 | cur_acc_val = AverageMeter() | ||
| 908 | |||
| 906 | with torch.inference_mode(): | 909 | with torch.inference_mode(): |
| 907 | for step, batch in enumerate(val_dataloader): | 910 | for step, batch in enumerate(val_dataloader): |
| 908 | loss, acc, bsz = loop(batch) | 911 | loss, acc, bsz = loop(batch) |
| 909 | 912 | ||
| 910 | avg_loss_val.update(loss.detach_(), bsz) | 913 | loss = loss.detach_() |
| 911 | avg_acc_val.update(acc.detach_(), bsz) | 914 | acc = acc.detach_() |
| 915 | |||
| 916 | cur_loss_val.update(loss, bsz) | ||
| 917 | cur_acc_val.update(acc, bsz) | ||
| 918 | |||
| 919 | avg_loss_val.update(loss, bsz) | ||
| 920 | avg_acc_val.update(acc, bsz) | ||
| 912 | 921 | ||
| 913 | local_progress_bar.update(1) | 922 | local_progress_bar.update(1) |
| 914 | global_progress_bar.update(1) | 923 | global_progress_bar.update(1) |
| @@ -921,10 +930,10 @@ def main(): | |||
| 921 | } | 930 | } |
| 922 | local_progress_bar.set_postfix(**logs) | 931 | local_progress_bar.set_postfix(**logs) |
| 923 | 932 | ||
| 924 | accelerator.log({ | 933 | logs["val/cur_loss"] = cur_loss_val.avg.item() |
| 925 | "val/loss": avg_loss_val.avg.item(), | 934 | logs["val/cur_acc"] = cur_acc_val.avg.item() |
| 926 | "val/acc": avg_acc_val.avg.item(), | 935 | |
| 927 | }, step=global_step) | 936 | accelerator.log(logs, step=global_step) |
| 928 | 937 | ||
| 929 | local_progress_bar.clear() | 938 | local_progress_bar.clear() |
| 930 | global_progress_bar.clear() | 939 | global_progress_bar.clear() |
