author     Volpeon <git@volpeon.ink>    2023-01-06 11:14:24 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-06 11:14:24 +0100
commit     672a59abeaa60dc5ef78a33bd9b58e391b922016 (patch)
tree       1afb3a943af3fa7c935d133cf2768a33f11f8235
parent     Package update (diff)
Use context manager for EMA, on_train/eval hooks
-rw-r--r--  train_ti.py      | 120
-rw-r--r--  training/lr.py   |  51
-rw-r--r--  training/util.py |   2
3 files changed, 92 insertions(+), 81 deletions(-)
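This commit turns the on_train/on_eval hooks into context managers: the training loop now switches the tokenizer into training mode and is guaranteed to switch it back on exit, and the validation loop can temporarily run under the EMA embedding weights via EMAModel.apply_temporary. A minimal, self-contained sketch of that pattern follows; the tokenizer, use_ema flag and ema_apply_temporary helper are toy stand-ins for the objects used in train_ti.py (the real hooks appear verbatim in the diff below), and the EMA condition is simplified compared to the committed code.

from contextlib import contextmanager, nullcontext

import torch

# Toy stand-ins so the sketch runs on its own; in train_ti.py these are the
# tokenizer (which has train()/eval() modes), args.use_ema, and
# ema_embeddings.apply_temporary(...) over the temp token embedding parameters.
tokenizer = torch.nn.Embedding(10, 4)
use_ema = False


def ema_apply_temporary():
    # Placeholder for EMAModel.apply_temporary(parameters); a no-op here.
    return nullcontext()


@contextmanager
def on_train():
    try:
        tokenizer.train()    # enter training mode for the whole block
        yield
    finally:
        tokenizer.eval()     # always restored, even if the loop body raises


@contextmanager
def on_eval():
    # Swap the EMA weights in only while validating; no-op when EMA is off.
    with (ema_apply_temporary() if use_ema else nullcontext()):
        yield


with on_train():
    assert tokenizer.training       # training-mode work goes here
assert not tokenizer.training       # mode restored on exit

with torch.inference_mode(), on_eval():
    pass                            # validation forward passes go here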
diff --git a/train_ti.py b/train_ti.py
index aa2bf02..f622299 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -2,10 +2,9 @@ import argparse
 import math
 import datetime
 import logging
-import copy
-from pathlib import Path
 from functools import partial
-from contextlib import nullcontext
+from pathlib import Path
+from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -849,11 +848,24 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    @contextmanager
     def on_train():
-        tokenizer.train()
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            tokenizer.eval()
 
+    @contextmanager
     def on_eval():
-        tokenizer.eval()
+        try:
+            ema_context = ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass
 
     loop = partial(
         run_model,
@@ -961,80 +973,80 @@ def main():
         local_progress_bar.reset()
 
         text_encoder.train()
-        on_train()
 
-        for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(text_encoder):
-                loss, acc, bsz = loop(step, batch)
-
-                accelerator.backward(loss)
-
-                optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
-
-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
-
-            # Checks if the accelerator has performed an optimization step behind the scenes
-            if accelerator.sync_gradients:
-                if args.use_ema:
-                    ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
-
-                global_step += 1
-
-                logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
-                    "train/cur_loss": loss.item(),
-                    "train/cur_acc": acc.item(),
-                    "lr": lr_scheduler.get_last_lr()[0],
-                }
-                if args.use_ema:
-                    logs["ema_decay"] = ema_embeddings.decay
-
-                accelerator.log(logs, step=global_step)
-
-                local_progress_bar.set_postfix(**logs)
-
-            if global_step >= args.max_train_steps:
-                break
+        with on_train():
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(text_encoder):
+                    loss, acc, bsz = loop(step, batch)
+
+                    accelerator.backward(loss)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    if args.use_ema:
+                        ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = ema_embeddings.decay
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)
+
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
 
         text_encoder.eval()
-        on_eval()
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
 
         with torch.inference_mode():
-            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(step, batch, True)
-
-                loss = loss.detach_()
-                acc = acc.detach_()
-
-                cur_loss_val.update(loss, bsz)
-                cur_acc_val.update(acc, bsz)
-
-                avg_loss_val.update(loss, bsz)
-                avg_acc_val.update(acc, bsz)
-
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
-
-            logs = {
-                "val/loss": avg_loss_val.avg.item(),
-                "val/acc": avg_acc_val.avg.item(),
-                "val/cur_loss": loss.item(),
-                "val/cur_acc": acc.item(),
-            }
-            local_progress_bar.set_postfix(**logs)
+            with on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loop(step, batch, True)
+
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
+
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                logs = {
+                    "val/loss": avg_loss_val.avg.item(),
+                    "val/acc": avg_acc_val.avg.item(),
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
+                local_progress_bar.set_postfix(**logs)
 
         logs["val/cur_loss"] = cur_loss_val.avg.item()
         logs["val/cur_acc"] = cur_acc_val.avg.item()
diff --git a/training/lr.py b/training/lr.py
index c765150..68e0f72 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,5 @@
 import math
-import copy
+from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union
 from functools import partial
 
@@ -25,9 +25,9 @@ class LRFinder():
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[], None] = noop,
+        on_train: Callable[[], _GeneratorContextManager] = nullcontext,
         on_clip: Callable[[], None] = noop,
-        on_eval: Callable[[], None] = noop
+        on_eval: Callable[[], _GeneratorContextManager] = nullcontext
     ):
         self.accelerator = accelerator
         self.model = model
@@ -51,7 +51,6 @@ class LRFinder():
         num_train_batches: int = 1,
         num_val_batches: int = math.inf,
         smooth_f: float = 0.05,
-        diverge_th: int = 5,
     ):
         best_loss = None
         best_acc = None
@@ -84,40 +83,40 @@ class LRFinder():
         avg_acc = AverageMeter()
 
         self.model.train()
-        self.on_train()
 
-        for step, batch in enumerate(self.train_dataloader):
-            if step >= num_train_batches:
-                break
-
-            with self.accelerator.accumulate(self.model):
-                loss, acc, bsz = self.loss_fn(step, batch)
-
-                self.accelerator.backward(loss)
-
-                if self.accelerator.sync_gradients:
-                    self.on_clip()
-
-                self.optimizer.step()
-                lr_scheduler.step()
-                self.optimizer.zero_grad(set_to_none=True)
-
-            if self.accelerator.sync_gradients:
-                progress_bar.update(1)
+        with self.on_train():
+            for step, batch in enumerate(self.train_dataloader):
+                if step >= num_train_batches:
+                    break
+
+                with self.accelerator.accumulate(self.model):
+                    loss, acc, bsz = self.loss_fn(step, batch)
+
+                    self.accelerator.backward(loss)
+
+                    if self.accelerator.sync_gradients:
+                        self.on_clip()
+
+                    self.optimizer.step()
+                    lr_scheduler.step()
+                    self.optimizer.zero_grad(set_to_none=True)
+
+                if self.accelerator.sync_gradients:
+                    progress_bar.update(1)
 
         self.model.eval()
-        self.on_eval()
 
         with torch.inference_mode():
-            for step, batch in enumerate(self.val_dataloader):
-                if step >= num_val_batches:
-                    break
-
-                loss, acc, bsz = self.loss_fn(step, batch, True)
-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
-
-                progress_bar.update(1)
-
-            loss = avg_loss.avg.item()
-            acc = avg_acc.avg.item()
+            with self.on_eval():
+                for step, batch in enumerate(self.val_dataloader):
+                    if step >= num_val_batches:
+                        break
+
+                    loss, acc, bsz = self.loss_fn(step, batch, True)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
+
+                    progress_bar.update(1)
+
+                loss = avg_loss.avg.item()
+                acc = avg_acc.avg.item()
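In training/lr.py the on_train/on_eval hooks change from plain callbacks to factories that return a context manager, with contextlib.nullcontext as the default, so LRFinder.run() can always write `with self.on_train():` whether or not the caller supplies hooks. A small self-contained sketch of why that works; example_hook and run are illustrative and not part of the repository:

from contextlib import contextmanager, nullcontext


@contextmanager
def example_hook():
    # Stands in for the on_train/on_eval context managers defined in train_ti.py.
    print("enter")
    try:
        yield
    finally:
        print("exit")


def run(on_train=nullcontext, on_eval=nullcontext):
    # The nullcontext default and a @contextmanager hook are both zero-argument
    # callables returning a context manager, matching
    # Callable[[], _GeneratorContextManager].
    with on_train():
        pass  # training batches would go here
    with on_eval():
        pass  # validation batches would go here


run()                        # inert defaults: both with-blocks are no-ops
run(on_train=example_hook)   # prints "enter"/"exit" around the training block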
diff --git a/training/util.py b/training/util.py
index 6f1e85a..bed7111 100644
--- a/training/util.py
+++ b/training/util.py
@@ -262,7 +262,7 @@ class EMAModel:
             raise ValueError("collected_params and shadow_params must have the same length")
 
     @contextmanager
-    def apply_temporary(self, parameters):
+    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
         try:
             parameters = list(parameters)
             original_params = [p.clone() for p in parameters]
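The hunk above only shows the start of apply_temporary; the rest of the method is unchanged and outside the diff context. For readers unfamiliar with the pattern, this is the general shape of an "apply EMA weights temporarily" context manager: snapshot the live parameters, copy the shadow (EMA) values in, yield, and restore the snapshot in a finally block. A self-contained sketch with a hypothetical helper name; it is not the EMAModel implementation from training/util.py:

from contextlib import contextmanager

import torch


@contextmanager
def swap_in_ema_weights(parameters, shadow_params):
    # Temporarily load the EMA (shadow) values into `parameters`;
    # restore the original live weights on exit, even on error.
    parameters = list(parameters)
    originals = [p.detach().clone() for p in parameters]
    try:
        for p, s in zip(parameters, shadow_params):
            p.data.copy_(s.data)          # evaluate with the smoothed weights
        yield
    finally:
        for p, o in zip(parameters, originals):
            p.data.copy_(o.data)          # back to the live training weights


# Usage on a toy module:
layer = torch.nn.Linear(2, 2)
shadow = [torch.zeros_like(p) for p in layer.parameters()]   # fake EMA values
with swap_in_ema_weights(layer.parameters(), shadow):
    assert layer.weight.abs().sum().item() == 0.0            # EMA weights active
assert layer.weight.abs().sum().item() > 0.0                 # originals restored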