diff options
| -rw-r--r-- | train_ti.py | 120 | ||||
| -rw-r--r-- | training/lr.py | 51 | ||||
| -rw-r--r-- | training/util.py | 2 |
3 files changed, 92 insertions, 81 deletions
diff --git a/train_ti.py b/train_ti.py index aa2bf02..f622299 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -2,10 +2,9 @@ import argparse | |||
| 2 | import math | 2 | import math |
| 3 | import datetime | 3 | import datetime |
| 4 | import logging | 4 | import logging |
| 5 | import copy | ||
| 6 | from pathlib import Path | ||
| 7 | from functools import partial | 5 | from functools import partial |
| 8 | from contextlib import nullcontext | 6 | from pathlib import Path |
| 7 | from contextlib import contextmanager, nullcontext | ||
| 9 | 8 | ||
| 10 | import torch | 9 | import torch |
| 11 | import torch.utils.checkpoint | 10 | import torch.utils.checkpoint |
| @@ -849,11 +848,24 @@ def main(): | |||
| 849 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 848 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
| 850 | val_steps = num_val_steps_per_epoch * num_epochs | 849 | val_steps = num_val_steps_per_epoch * num_epochs |
| 851 | 850 | ||
| 851 | @contextmanager | ||
| 852 | def on_train(): | 852 | def on_train(): |
| 853 | tokenizer.train() | 853 | try: |
| 854 | tokenizer.train() | ||
| 855 | yield | ||
| 856 | finally: | ||
| 857 | tokenizer.eval() | ||
| 854 | 858 | ||
| 859 | @contextmanager | ||
| 855 | def on_eval(): | 860 | def on_eval(): |
| 856 | tokenizer.eval() | 861 | try: |
| 862 | ema_context = ema_embeddings.apply_temporary( | ||
| 863 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext() | ||
| 864 | |||
| 865 | with ema_context: | ||
| 866 | yield | ||
| 867 | finally: | ||
| 868 | pass | ||
| 857 | 869 | ||
| 858 | loop = partial( | 870 | loop = partial( |
| 859 | run_model, | 871 | run_model, |
| @@ -961,80 +973,80 @@ def main(): | |||
| 961 | local_progress_bar.reset() | 973 | local_progress_bar.reset() |
| 962 | 974 | ||
| 963 | text_encoder.train() | 975 | text_encoder.train() |
| 964 | on_train() | ||
| 965 | 976 | ||
| 966 | for step, batch in enumerate(train_dataloader): | 977 | with on_train(): |
| 967 | with accelerator.accumulate(text_encoder): | 978 | for step, batch in enumerate(train_dataloader): |
| 968 | loss, acc, bsz = loop(step, batch) | 979 | with accelerator.accumulate(text_encoder): |
| 980 | loss, acc, bsz = loop(step, batch) | ||
| 969 | 981 | ||
| 970 | accelerator.backward(loss) | 982 | accelerator.backward(loss) |
| 971 | 983 | ||
| 972 | optimizer.step() | 984 | optimizer.step() |
| 973 | if not accelerator.optimizer_step_was_skipped: | 985 | if not accelerator.optimizer_step_was_skipped: |
| 974 | lr_scheduler.step() | 986 | lr_scheduler.step() |
| 975 | optimizer.zero_grad(set_to_none=True) | 987 | optimizer.zero_grad(set_to_none=True) |
| 976 | 988 | ||
| 977 | avg_loss.update(loss.detach_(), bsz) | 989 | avg_loss.update(loss.detach_(), bsz) |
| 978 | avg_acc.update(acc.detach_(), bsz) | 990 | avg_acc.update(acc.detach_(), bsz) |
| 979 | 991 | ||
| 980 | # Checks if the accelerator has performed an optimization step behind the scenes | 992 | # Checks if the accelerator has performed an optimization step behind the scenes |
| 981 | if accelerator.sync_gradients: | 993 | if accelerator.sync_gradients: |
| 982 | if args.use_ema: | 994 | if args.use_ema: |
| 983 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 995 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) |
| 984 | 996 | ||
| 985 | local_progress_bar.update(1) | 997 | local_progress_bar.update(1) |
| 986 | global_progress_bar.update(1) | 998 | global_progress_bar.update(1) |
| 987 | 999 | ||
| 988 | global_step += 1 | 1000 | global_step += 1 |
| 989 | 1001 | ||
| 990 | logs = { | 1002 | logs = { |
| 991 | "train/loss": avg_loss.avg.item(), | 1003 | "train/loss": avg_loss.avg.item(), |
| 992 | "train/acc": avg_acc.avg.item(), | 1004 | "train/acc": avg_acc.avg.item(), |
| 993 | "train/cur_loss": loss.item(), | 1005 | "train/cur_loss": loss.item(), |
| 994 | "train/cur_acc": acc.item(), | 1006 | "train/cur_acc": acc.item(), |
| 995 | "lr": lr_scheduler.get_last_lr()[0], | 1007 | "lr": lr_scheduler.get_last_lr()[0], |
| 996 | } | 1008 | } |
| 997 | if args.use_ema: | 1009 | if args.use_ema: |
| 998 | logs["ema_decay"] = ema_embeddings.decay | 1010 | logs["ema_decay"] = ema_embeddings.decay |
| 999 | 1011 | ||
| 1000 | accelerator.log(logs, step=global_step) | 1012 | accelerator.log(logs, step=global_step) |
| 1001 | 1013 | ||
| 1002 | local_progress_bar.set_postfix(**logs) | 1014 | local_progress_bar.set_postfix(**logs) |
| 1003 | 1015 | ||
| 1004 | if global_step >= args.max_train_steps: | 1016 | if global_step >= args.max_train_steps: |
| 1005 | break | 1017 | break |
| 1006 | 1018 | ||
| 1007 | accelerator.wait_for_everyone() | 1019 | accelerator.wait_for_everyone() |
| 1008 | 1020 | ||
| 1009 | text_encoder.eval() | 1021 | text_encoder.eval() |
| 1010 | on_eval() | ||
| 1011 | 1022 | ||
| 1012 | cur_loss_val = AverageMeter() | 1023 | cur_loss_val = AverageMeter() |
| 1013 | cur_acc_val = AverageMeter() | 1024 | cur_acc_val = AverageMeter() |
| 1014 | 1025 | ||
| 1015 | with torch.inference_mode(): | 1026 | with torch.inference_mode(): |
| 1016 | for step, batch in enumerate(val_dataloader): | 1027 | with on_eval(): |
| 1017 | loss, acc, bsz = loop(step, batch, True) | 1028 | for step, batch in enumerate(val_dataloader): |
| 1029 | loss, acc, bsz = loop(step, batch, True) | ||
| 1018 | 1030 | ||
| 1019 | loss = loss.detach_() | 1031 | loss = loss.detach_() |
| 1020 | acc = acc.detach_() | 1032 | acc = acc.detach_() |
| 1021 | 1033 | ||
| 1022 | cur_loss_val.update(loss, bsz) | 1034 | cur_loss_val.update(loss, bsz) |
| 1023 | cur_acc_val.update(acc, bsz) | 1035 | cur_acc_val.update(acc, bsz) |
| 1024 | 1036 | ||
| 1025 | avg_loss_val.update(loss, bsz) | 1037 | avg_loss_val.update(loss, bsz) |
| 1026 | avg_acc_val.update(acc, bsz) | 1038 | avg_acc_val.update(acc, bsz) |
| 1027 | 1039 | ||
| 1028 | local_progress_bar.update(1) | 1040 | local_progress_bar.update(1) |
| 1029 | global_progress_bar.update(1) | 1041 | global_progress_bar.update(1) |
| 1030 | 1042 | ||
| 1031 | logs = { | 1043 | logs = { |
| 1032 | "val/loss": avg_loss_val.avg.item(), | 1044 | "val/loss": avg_loss_val.avg.item(), |
| 1033 | "val/acc": avg_acc_val.avg.item(), | 1045 | "val/acc": avg_acc_val.avg.item(), |
| 1034 | "val/cur_loss": loss.item(), | 1046 | "val/cur_loss": loss.item(), |
| 1035 | "val/cur_acc": acc.item(), | 1047 | "val/cur_acc": acc.item(), |
| 1036 | } | 1048 | } |
| 1037 | local_progress_bar.set_postfix(**logs) | 1049 | local_progress_bar.set_postfix(**logs) |
| 1038 | 1050 | ||
| 1039 | logs["val/cur_loss"] = cur_loss_val.avg.item() | 1051 | logs["val/cur_loss"] = cur_loss_val.avg.item() |
| 1040 | logs["val/cur_acc"] = cur_acc_val.avg.item() | 1052 | logs["val/cur_acc"] = cur_acc_val.avg.item() |
diff --git a/training/lr.py b/training/lr.py index c765150..68e0f72 100644 --- a/training/lr.py +++ b/training/lr.py | |||
| @@ -1,5 +1,5 @@ | |||
| 1 | import math | 1 | import math |
| 2 | import copy | 2 | from contextlib import _GeneratorContextManager, nullcontext |
| 3 | from typing import Callable, Any, Tuple, Union | 3 | from typing import Callable, Any, Tuple, Union |
| 4 | from functools import partial | 4 | from functools import partial |
| 5 | 5 | ||
| @@ -25,9 +25,9 @@ class LRFinder(): | |||
| 25 | train_dataloader, | 25 | train_dataloader, |
| 26 | val_dataloader, | 26 | val_dataloader, |
| 27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
| 28 | on_train: Callable[[], None] = noop, | 28 | on_train: Callable[[], _GeneratorContextManager] = nullcontext, |
| 29 | on_clip: Callable[[], None] = noop, | 29 | on_clip: Callable[[], None] = noop, |
| 30 | on_eval: Callable[[], None] = noop | 30 | on_eval: Callable[[], _GeneratorContextManager] = nullcontext |
| 31 | ): | 31 | ): |
| 32 | self.accelerator = accelerator | 32 | self.accelerator = accelerator |
| 33 | self.model = model | 33 | self.model = model |
| @@ -51,7 +51,6 @@ class LRFinder(): | |||
| 51 | num_train_batches: int = 1, | 51 | num_train_batches: int = 1, |
| 52 | num_val_batches: int = math.inf, | 52 | num_val_batches: int = math.inf, |
| 53 | smooth_f: float = 0.05, | 53 | smooth_f: float = 0.05, |
| 54 | diverge_th: int = 5, | ||
| 55 | ): | 54 | ): |
| 56 | best_loss = None | 55 | best_loss = None |
| 57 | best_acc = None | 56 | best_acc = None |
| @@ -84,40 +83,40 @@ class LRFinder(): | |||
| 84 | avg_acc = AverageMeter() | 83 | avg_acc = AverageMeter() |
| 85 | 84 | ||
| 86 | self.model.train() | 85 | self.model.train() |
| 87 | self.on_train() | ||
| 88 | 86 | ||
| 89 | for step, batch in enumerate(self.train_dataloader): | 87 | with self.on_train(): |
| 90 | if step >= num_train_batches: | 88 | for step, batch in enumerate(self.train_dataloader): |
| 91 | break | 89 | if step >= num_train_batches: |
| 90 | break | ||
| 92 | 91 | ||
| 93 | with self.accelerator.accumulate(self.model): | 92 | with self.accelerator.accumulate(self.model): |
| 94 | loss, acc, bsz = self.loss_fn(step, batch) | 93 | loss, acc, bsz = self.loss_fn(step, batch) |
| 95 | 94 | ||
| 96 | self.accelerator.backward(loss) | 95 | self.accelerator.backward(loss) |
| 97 | 96 | ||
| 98 | if self.accelerator.sync_gradients: | 97 | if self.accelerator.sync_gradients: |
| 99 | self.on_clip() | 98 | self.on_clip() |
| 100 | 99 | ||
| 101 | self.optimizer.step() | 100 | self.optimizer.step() |
| 102 | lr_scheduler.step() | 101 | lr_scheduler.step() |
| 103 | self.optimizer.zero_grad(set_to_none=True) | 102 | self.optimizer.zero_grad(set_to_none=True) |
| 104 | 103 | ||
| 105 | if self.accelerator.sync_gradients: | 104 | if self.accelerator.sync_gradients: |
| 106 | progress_bar.update(1) | 105 | progress_bar.update(1) |
| 107 | 106 | ||
| 108 | self.model.eval() | 107 | self.model.eval() |
| 109 | self.on_eval() | ||
| 110 | 108 | ||
| 111 | with torch.inference_mode(): | 109 | with torch.inference_mode(): |
| 112 | for step, batch in enumerate(self.val_dataloader): | 110 | with self.on_eval(): |
| 113 | if step >= num_val_batches: | 111 | for step, batch in enumerate(self.val_dataloader): |
| 114 | break | 112 | if step >= num_val_batches: |
| 113 | break | ||
| 115 | 114 | ||
| 116 | loss, acc, bsz = self.loss_fn(step, batch, True) | 115 | loss, acc, bsz = self.loss_fn(step, batch, True) |
| 117 | avg_loss.update(loss.detach_(), bsz) | 116 | avg_loss.update(loss.detach_(), bsz) |
| 118 | avg_acc.update(acc.detach_(), bsz) | 117 | avg_acc.update(acc.detach_(), bsz) |
| 119 | 118 | ||
| 120 | progress_bar.update(1) | 119 | progress_bar.update(1) |
| 121 | 120 | ||
| 122 | loss = avg_loss.avg.item() | 121 | loss = avg_loss.avg.item() |
| 123 | acc = avg_acc.avg.item() | 122 | acc = avg_acc.avg.item() |
diff --git a/training/util.py b/training/util.py index 6f1e85a..bed7111 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -262,7 +262,7 @@ class EMAModel: | |||
| 262 | raise ValueError("collected_params and shadow_params must have the same length") | 262 | raise ValueError("collected_params and shadow_params must have the same length") |
| 263 | 263 | ||
| 264 | @contextmanager | 264 | @contextmanager |
| 265 | def apply_temporary(self, parameters): | 265 | def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]): |
| 266 | try: | 266 | try: |
| 267 | parameters = list(parameters) | 267 | parameters = list(parameters) |
| 268 | original_params = [p.clone() for p in parameters] | 268 | original_params = [p.clone() for p in parameters] |
