diff options
author | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
commit | 6d46bf79bd7710cea799fbfe27c12d06d12cd53f (patch) | |
tree | 6c65817b9351453bfb5366f7010f8d87659c0dd0 /train_ti.py | |
parent | Fix cycle loop (diff) | |
download | textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.gz textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.bz2 textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index d1e5467..fce4a5e 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter | |||
23 | from training.functional import train, add_placeholder_tokens, get_models | 23 | from training.functional import train, add_placeholder_tokens, get_models |
24 | from training.strategy.ti import textual_inversion_strategy | 24 | from training.strategy.ti import textual_inversion_strategy |
25 | from training.optimization import get_scheduler | 25 | from training.optimization import get_scheduler |
26 | from training.util import save_args | 26 | from training.util import AverageMeter, save_args |
27 | 27 | ||
28 | logger = get_logger(__name__) | 28 | logger = get_logger(__name__) |
29 | 29 | ||
@@ -920,6 +920,11 @@ def main(): | |||
920 | lr_warmup_epochs = args.lr_warmup_epochs | 920 | lr_warmup_epochs = args.lr_warmup_epochs |
921 | lr_cycles = args.lr_cycles | 921 | lr_cycles = args.lr_cycles |
922 | 922 | ||
923 | avg_loss = AverageMeter() | ||
924 | avg_acc = AverageMeter() | ||
925 | avg_loss_val = AverageMeter() | ||
926 | avg_acc_val = AverageMeter() | ||
927 | |||
923 | while True: | 928 | while True: |
924 | if len(auto_cycles) != 0: | 929 | if len(auto_cycles) != 0: |
925 | response = auto_cycles.pop(0) | 930 | response = auto_cycles.pop(0) |
@@ -977,7 +982,7 @@ def main(): | |||
977 | mid_point=args.lr_mid_point, | 982 | mid_point=args.lr_mid_point, |
978 | ) | 983 | ) |
979 | 984 | ||
980 | checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}" | 985 | checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}" |
981 | 986 | ||
982 | trainer( | 987 | trainer( |
983 | train_dataloader=datamodule.train_dataloader, | 988 | train_dataloader=datamodule.train_dataloader, |
@@ -994,6 +999,10 @@ def main(): | |||
994 | sample_frequency=sample_frequency, | 999 | sample_frequency=sample_frequency, |
995 | placeholder_tokens=placeholder_tokens, | 1000 | placeholder_tokens=placeholder_tokens, |
996 | placeholder_token_ids=placeholder_token_ids, | 1001 | placeholder_token_ids=placeholder_token_ids, |
1002 | avg_loss=avg_loss, | ||
1003 | avg_acc=avg_acc, | ||
1004 | avg_loss_val=avg_loss_val, | ||
1005 | avg_acc_val=avg_acc_val, | ||
997 | ) | 1006 | ) |
998 | 1007 | ||
999 | training_iter += 1 | 1008 | training_iter += 1 |