summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-27 07:47:59 +0200
committerVolpeon <git@volpeon.ink>2023-04-27 07:47:59 +0200
commit6d46bf79bd7710cea799fbfe27c12d06d12cd53f (patch)
tree6c65817b9351453bfb5366f7010f8d87659c0dd0 /train_ti.py
parentFix cycle loop (diff)
downloadtextual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.gz
textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.bz2
textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py
index d1e5467..fce4a5e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
23from training.functional import train, add_placeholder_tokens, get_models 23from training.functional import train, add_placeholder_tokens, get_models
24from training.strategy.ti import textual_inversion_strategy 24from training.strategy.ti import textual_inversion_strategy
25from training.optimization import get_scheduler 25from training.optimization import get_scheduler
26from training.util import save_args 26from training.util import AverageMeter, save_args
27 27
28logger = get_logger(__name__) 28logger = get_logger(__name__)
29 29
@@ -920,6 +920,11 @@ def main():
920 lr_warmup_epochs = args.lr_warmup_epochs 920 lr_warmup_epochs = args.lr_warmup_epochs
921 lr_cycles = args.lr_cycles 921 lr_cycles = args.lr_cycles
922 922
923 avg_loss = AverageMeter()
924 avg_acc = AverageMeter()
925 avg_loss_val = AverageMeter()
926 avg_acc_val = AverageMeter()
927
923 while True: 928 while True:
924 if len(auto_cycles) != 0: 929 if len(auto_cycles) != 0:
925 response = auto_cycles.pop(0) 930 response = auto_cycles.pop(0)
@@ -977,7 +982,7 @@ def main():
977 mid_point=args.lr_mid_point, 982 mid_point=args.lr_mid_point,
978 ) 983 )
979 984
980 checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}" 985 checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}"
981 986
982 trainer( 987 trainer(
983 train_dataloader=datamodule.train_dataloader, 988 train_dataloader=datamodule.train_dataloader,
@@ -994,6 +999,10 @@ def main():
994 sample_frequency=sample_frequency, 999 sample_frequency=sample_frequency,
995 placeholder_tokens=placeholder_tokens, 1000 placeholder_tokens=placeholder_tokens,
996 placeholder_token_ids=placeholder_token_ids, 1001 placeholder_token_ids=placeholder_token_ids,
1002 avg_loss=avg_loss,
1003 avg_acc=avg_acc,
1004 avg_loss_val=avg_loss_val,
1005 avg_acc_val=avg_acc_val,
997 ) 1006 )
998 1007
999 training_iter += 1 1008 training_iter += 1