summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-05 22:58:31 +0100
committerVolpeon <git@volpeon.ink>2023-01-05 22:58:31 +0100
commitc3e59b3f075cbb549b21a905f03269d3d29fb47b (patch)
tree3af3fedd22d3aff5d8037c4d3a67035d0049c74c /train_ti.py
parentAdded EMA to TI (diff)
downloadtextual-inversion-diff-c3e59b3f075cbb549b21a905f03269d3d29fb47b.tar.gz
textual-inversion-diff-c3e59b3f075cbb549b21a905f03269d3d29fb47b.tar.bz2
textual-inversion-diff-c3e59b3f075cbb549b21a905f03269d3d29fb47b.zip
Log EMA decay
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index dc36e42..2f13128 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -270,7 +270,6 @@ def parse_args():
270 parser.add_argument( 270 parser.add_argument(
271 "--use_ema", 271 "--use_ema",
272 action="store_true", 272 action="store_true",
273 default=True,
274 help="Whether to use EMA model." 273 help="Whether to use EMA model."
275 ) 274 )
276 parser.add_argument( 275 parser.add_argument(
@@ -1004,6 +1003,8 @@ def main():
1004 "train/cur_acc": acc.item(), 1003 "train/cur_acc": acc.item(),
1005 "lr": lr_scheduler.get_last_lr()[0], 1004 "lr": lr_scheduler.get_last_lr()[0],
1006 } 1005 }
1006 if args.use_ema:
1007 logs["ema_decay"] = ema_embeddings.decay
1007 1008
1008 accelerator.log(logs, step=global_step) 1009 accelerator.log(logs, step=global_step)
1009 1010