author     Volpeon <git@volpeon.ink>  2023-01-06 11:14:24 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-06 11:14:24 +0100
commit     672a59abeaa60dc5ef78a33bd9b58e391b922016 (patch)
tree       1afb3a943af3fa7c935d133cf2768a33f11f8235 /train_ti.py
parent     Package update (diff)
Use context manager for EMA, on_train/eval hooks
Diffstat (limited to 'train_ti.py'):

 -rw-r--r--  train_ti.py | 120

1 file changed, 66 insertions(+), 54 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index aa2bf02..f622299 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -2,10 +2,9 @@ import argparse
 import math
 import datetime
 import logging
-import copy
-from pathlib import Path
 from functools import partial
-from contextlib import nullcontext
+from pathlib import Path
+from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
```
```diff
@@ -849,11 +848,24 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    @contextmanager
     def on_train():
-        tokenizer.train()
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            tokenizer.eval()
 
+    @contextmanager
     def on_eval():
-        tokenizer.eval()
+        try:
+            ema_context = ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass
 
     loop = partial(
         run_model,
```
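The move from plain `on_train()` / `on_eval()` calls to `with on_train():` blocks relies on `contextlib.contextmanager`: code before the `yield` runs on entry, and the `finally` branch runs on exit even if the body raises, so the tokenizer can no longer be left stuck in training mode. A minimal runnable sketch of the pattern follows; the `Tokenizer` stand-in class is hypothetical, and only the train/eval toggling mirrors the diff:

```python
from contextlib import contextmanager

class Tokenizer:
    """Hypothetical stand-in with the train()/eval() mode toggle used above."""

    def __init__(self):
        self.training = False

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

tokenizer = Tokenizer()

@contextmanager
def on_train():
    try:
        tokenizer.train()  # entry: switch to training mode
        yield              # the body of the `with` block runs here
    finally:
        tokenizer.eval()   # exit: always restored, even on an exception

with on_train():
    assert tokenizer.training
assert not tokenizer.training  # eval mode restored after the block
```

Compared with the old paired calls, the teardown can no longer be forgotten at any exit path from the loop.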
```diff
@@ -961,80 +973,80 @@ def main():
         local_progress_bar.reset()
 
         text_encoder.train()
-        on_train()
 
-        for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(text_encoder):
-                loss, acc, bsz = loop(step, batch)
+        with on_train():
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(text_encoder):
+                    loss, acc, bsz = loop(step, batch)
 
-                accelerator.backward(loss)
+                    accelerator.backward(loss)
 
-                optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
 
-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
-            # Checks if the accelerator has performed an optimization step behind the scenes
-            if accelerator.sync_gradients:
-                if args.use_ema:
-                    ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    if args.use_ema:
+                        ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-                global_step += 1
+                    global_step += 1
 
-                logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
-                    "train/cur_loss": loss.item(),
-                    "train/cur_acc": acc.item(),
-                    "lr": lr_scheduler.get_last_lr()[0],
-                }
-                if args.use_ema:
-                    logs["ema_decay"] = ema_embeddings.decay
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = ema_embeddings.decay
 
-                accelerator.log(logs, step=global_step)
+                    accelerator.log(logs, step=global_step)
 
-                local_progress_bar.set_postfix(**logs)
+                    local_progress_bar.set_postfix(**logs)
 
-            if global_step >= args.max_train_steps:
-                break
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
 
         text_encoder.eval()
-        on_eval()
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
 
         with torch.inference_mode():
-            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(step, batch, True)
+            with on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loop(step, batch, True)
 
-                loss = loss.detach_()
-                acc = acc.detach_()
+                    loss = loss.detach_()
+                    acc = acc.detach_()
 
-                cur_loss_val.update(loss, bsz)
-                cur_acc_val.update(acc, bsz)
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
 
-                avg_loss_val.update(loss, bsz)
-                avg_acc_val.update(acc, bsz)
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-                logs = {
-                    "val/loss": avg_loss_val.avg.item(),
-                    "val/acc": avg_acc_val.avg.item(),
-                    "val/cur_loss": loss.item(),
-                    "val/cur_acc": acc.item(),
-                }
-                local_progress_bar.set_postfix(**logs)
+                    logs = {
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
+                    local_progress_bar.set_postfix(**logs)
 
         logs["val/cur_loss"] = cur_loss_val.avg.item()
         logs["val/cur_acc"] = cur_acc_val.avg.item()
```
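Inside `on_eval()`, `ema_embeddings.apply_temporary(...)` is itself a context manager: it swaps the EMA shadow weights into the live embedding for the duration of validation and restores the originals on exit, while the training loop keeps updating the shadow via `ema_embeddings.step(...)` once per synced optimizer step. The repo's actual `EMAModel` implementation is not shown in this diff; the sketch below is a hypothetical stand-in illustrating what such a `step`/`apply_temporary` pair typically does:

```python
import torch
from contextlib import contextmanager, nullcontext

class EMAModel:
    """Hypothetical stand-in for the repo's EMA wrapper (not its actual code)."""

    def __init__(self, parameters, decay=0.9999):
        self.decay = decay
        # Shadow copies that track an exponential moving average of the params.
        self.shadow = [p.detach().clone() for p in parameters]

    def step(self, parameters):
        # Standard EMA update: shadow <- decay * shadow + (1 - decay) * param.
        for s, p in zip(self.shadow, parameters):
            s.mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    @contextmanager
    def apply_temporary(self, parameters):
        # Swap the EMA weights in for the duration of the block, then restore.
        parameters = list(parameters)
        backup = [p.detach().clone() for p in parameters]
        try:
            for s, p in zip(self.shadow, parameters):
                p.data.copy_(s)
            yield
        finally:
            for b, p in zip(backup, parameters):
                p.data.copy_(b)

# Usage mirroring on_eval() above: fall back to nullcontext when EMA is off.
embedding = torch.nn.Embedding(4, 8)
ema = EMAModel(embedding.parameters())
use_ema = True
ctx = ema.apply_temporary(embedding.parameters()) if use_ema else nullcontext()
with ctx:
    pass  # run validation here with the EMA weights active
```

The `nullcontext()` fallback is what lets `on_eval()` keep a single `with ema_context:` code path whether or not EMA is enabled.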