author     Volpeon <git@volpeon.ink>   2023-01-06 11:14:24 +0100
committer  Volpeon <git@volpeon.ink>   2023-01-06 11:14:24 +0100
commit     672a59abeaa60dc5ef78a33bd9b58e391b922016 (patch)
tree       1afb3a943af3fa7c935d133cf2768a33f11f8235 /train_ti.py
parent     Package update (diff)
Use context manager for EMA, on_train/eval hooks
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  120
1 file changed, 66 insertions(+), 54 deletions(-)
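
The change below converts the imperative on_train()/on_eval() hooks into generator-based context managers, so the tokenizer's mode (and, during evaluation, any temporarily applied EMA weights) is restored even when an exception or an early break exits the loop. A minimal sketch of the pattern, with illustrative names (training_mode, module) that are not part of train_ti.py:

from contextlib import contextmanager

@contextmanager
def training_mode(module):
    # Enter training mode for the duration of the with-block; the
    # finally clause guarantees eval mode is restored on any exit path.
    try:
        module.train()
        yield module
    finally:
        module.eval()

# Usage: even if the loop raises, the module ends up back in eval mode.
# with training_mode(text_encoder):
#     for step, batch in enumerate(train_dataloader):
#         ...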
diff --git a/train_ti.py b/train_ti.py
index aa2bf02..f622299 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -2,10 +2,9 @@ import argparse
 import math
 import datetime
 import logging
-import copy
-from pathlib import Path
 from functools import partial
-from contextlib import nullcontext
+from pathlib import Path
+from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -849,11 +848,24 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    @contextmanager
     def on_train():
-        tokenizer.train()
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            tokenizer.eval()
 
+    @contextmanager
     def on_eval():
-        tokenizer.eval()
+        try:
+            ema_context = ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass
 
     loop = partial(
         run_model,
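
Two details of the new on_eval() are worth noting: the finally: pass block is a no-op placeholder (cleanup is delegated to ema_context itself), and the guard args.use_ema is not None and eval references the builtin eval, which is always truthy, so the EMA context is entered whenever --use_ema is set. A sketch of what an apply_temporary-style context manager conventionally does (illustrative, assuming an EMA class with a copy_to() method; the repo's own EMAModel may differ):

from contextlib import contextmanager
import torch

@contextmanager
def apply_temporary(self, parameters):
    # Back up the live weights, swap in the EMA shadow weights, and
    # restore the originals when the with-block exits.
    parameters = list(parameters)
    backup = [p.detach().clone() for p in parameters]
    try:
        self.copy_to(parameters)  # assumed: copies EMA averages into parameters
        yield
    finally:
        with torch.no_grad():
            for param, saved in zip(parameters, backup):
                param.copy_(saved)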
@@ -961,80 +973,80 @@ def main():
             local_progress_bar.reset()
 
         text_encoder.train()
-        on_train()
 
-        for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(text_encoder):
-                loss, acc, bsz = loop(step, batch)
+        with on_train():
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(text_encoder):
+                    loss, acc, bsz = loop(step, batch)
 
-                accelerator.backward(loss)
+                    accelerator.backward(loss)
 
-                optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
 
-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
-            # Checks if the accelerator has performed an optimization step behind the scenes
-            if accelerator.sync_gradients:
-                if args.use_ema:
-                    ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    if args.use_ema:
+                        ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-                global_step += 1
+                    global_step += 1
 
-            logs = {
-                "train/loss": avg_loss.avg.item(),
-                "train/acc": avg_acc.avg.item(),
-                "train/cur_loss": loss.item(),
-                "train/cur_acc": acc.item(),
-                "lr": lr_scheduler.get_last_lr()[0],
-            }
-            if args.use_ema:
-                logs["ema_decay"] = ema_embeddings.decay
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
+                if args.use_ema:
+                    logs["ema_decay"] = ema_embeddings.decay
 
-            accelerator.log(logs, step=global_step)
+                accelerator.log(logs, step=global_step)
 
-            local_progress_bar.set_postfix(**logs)
+                local_progress_bar.set_postfix(**logs)
 
-            if global_step >= args.max_train_steps:
-                break
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
 
         text_encoder.eval()
-        on_eval()
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
 
         with torch.inference_mode():
-            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(step, batch, True)
+            with on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loop(step, batch, True)
 
-                loss = loss.detach_()
-                acc = acc.detach_()
+                    loss = loss.detach_()
+                    acc = acc.detach_()
 
-                cur_loss_val.update(loss, bsz)
-                cur_acc_val.update(acc, bsz)
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
 
-                avg_loss_val.update(loss, bsz)
-                avg_acc_val.update(acc, bsz)
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-            logs = {
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-                "val/cur_loss": loss.item(),
-                "val/cur_acc": acc.item(),
-            }
-            local_progress_bar.set_postfix(**logs)
+                logs = {
+                    "val/loss": avg_loss_val.avg.item(),
+                    "val/acc": avg_acc_val.avg.item(),
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
+                local_progress_bar.set_postfix(**logs)
 
         logs["val/cur_loss"] = cur_loss_val.avg.item()
         logs["val/cur_acc"] = cur_acc_val.avg.item()
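
With the hooks rewritten, each epoch reads as two plainly scoped phases. A condensed sketch of the resulting control flow (names as in train_ti.py, loop bodies elided):

text_encoder.train()
with on_train():                  # tokenizer in train mode; eval restored on exit
    for step, batch in enumerate(train_dataloader):
        ...                       # forward/backward, optimizer step, EMA update

text_encoder.eval()
with torch.inference_mode():
    with on_eval():               # EMA weights applied temporarily if enabled
        for step, batch in enumerate(val_dataloader):
            ...                   # validation loss/accuracy metrics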