From 6d46bf79bd7710cea799fbfe27c12d06d12cd53f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 27 Apr 2023 07:47:59 +0200 Subject: Update --- train_ti.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index d1e5467..fce4a5e 100644 --- a/train_ti.py +++ b/train_ti.py @@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter from training.functional import train, add_placeholder_tokens, get_models from training.strategy.ti import textual_inversion_strategy from training.optimization import get_scheduler -from training.util import save_args +from training.util import AverageMeter, save_args logger = get_logger(__name__) @@ -920,6 +920,11 @@ def main(): lr_warmup_epochs = args.lr_warmup_epochs lr_cycles = args.lr_cycles + avg_loss = AverageMeter() + avg_acc = AverageMeter() + avg_loss_val = AverageMeter() + avg_acc_val = AverageMeter() + while True: if len(auto_cycles) != 0: response = auto_cycles.pop(0) @@ -977,7 +982,7 @@ def main(): mid_point=args.lr_mid_point, ) - checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}" + checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}" trainer( train_dataloader=datamodule.train_dataloader, @@ -994,6 +999,10 @@ def main(): sample_frequency=sample_frequency, placeholder_tokens=placeholder_tokens, placeholder_token_ids=placeholder_token_ids, + avg_loss=avg_loss, + avg_acc=avg_acc, + avg_loss_val=avg_loss_val, + avg_acc_val=avg_acc_val, ) training_iter += 1 -- cgit v1.2.3-54-g00ecf