diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
| commit | 6d46bf79bd7710cea799fbfe27c12d06d12cd53f (patch) | |
| tree | 6c65817b9351453bfb5366f7010f8d87659c0dd0 /train_ti.py | |
| parent | Fix cycle loop (diff) | |
| download | textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.gz textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.bz2 textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index d1e5467..fce4a5e 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter | |||
| 23 | from training.functional import train, add_placeholder_tokens, get_models | 23 | from training.functional import train, add_placeholder_tokens, get_models |
| 24 | from training.strategy.ti import textual_inversion_strategy | 24 | from training.strategy.ti import textual_inversion_strategy |
| 25 | from training.optimization import get_scheduler | 25 | from training.optimization import get_scheduler |
| 26 | from training.util import save_args | 26 | from training.util import AverageMeter, save_args |
| 27 | 27 | ||
| 28 | logger = get_logger(__name__) | 28 | logger = get_logger(__name__) |
| 29 | 29 | ||
| @@ -920,6 +920,11 @@ def main(): | |||
| 920 | lr_warmup_epochs = args.lr_warmup_epochs | 920 | lr_warmup_epochs = args.lr_warmup_epochs |
| 921 | lr_cycles = args.lr_cycles | 921 | lr_cycles = args.lr_cycles |
| 922 | 922 | ||
| 923 | avg_loss = AverageMeter() | ||
| 924 | avg_acc = AverageMeter() | ||
| 925 | avg_loss_val = AverageMeter() | ||
| 926 | avg_acc_val = AverageMeter() | ||
| 927 | |||
| 923 | while True: | 928 | while True: |
| 924 | if len(auto_cycles) != 0: | 929 | if len(auto_cycles) != 0: |
| 925 | response = auto_cycles.pop(0) | 930 | response = auto_cycles.pop(0) |
| @@ -977,7 +982,7 @@ def main(): | |||
| 977 | mid_point=args.lr_mid_point, | 982 | mid_point=args.lr_mid_point, |
| 978 | ) | 983 | ) |
| 979 | 984 | ||
| 980 | checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}" | 985 | checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}" |
| 981 | 986 | ||
| 982 | trainer( | 987 | trainer( |
| 983 | train_dataloader=datamodule.train_dataloader, | 988 | train_dataloader=datamodule.train_dataloader, |
| @@ -994,6 +999,10 @@ def main(): | |||
| 994 | sample_frequency=sample_frequency, | 999 | sample_frequency=sample_frequency, |
| 995 | placeholder_tokens=placeholder_tokens, | 1000 | placeholder_tokens=placeholder_tokens, |
| 996 | placeholder_token_ids=placeholder_token_ids, | 1001 | placeholder_token_ids=placeholder_token_ids, |
| 1002 | avg_loss=avg_loss, | ||
| 1003 | avg_acc=avg_acc, | ||
| 1004 | avg_loss_val=avg_loss_val, | ||
| 1005 | avg_acc_val=avg_acc_val, | ||
| 997 | ) | 1006 | ) |
| 998 | 1007 | ||
| 999 | training_iter += 1 | 1008 | training_iter += 1 |
