diff options
-rw-r--r-- | train_ti.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index dc36e42..2f13128 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -270,7 +270,6 @@ def parse_args(): | |||
270 | parser.add_argument( | 270 | parser.add_argument( |
271 | "--use_ema", | 271 | "--use_ema", |
272 | action="store_true", | 272 | action="store_true", |
273 | default=True, | ||
274 | help="Whether to use EMA model." | 273 | help="Whether to use EMA model." |
275 | ) | 274 | ) |
276 | parser.add_argument( | 275 | parser.add_argument( |
@@ -1004,6 +1003,8 @@ def main(): | |||
1004 | "train/cur_acc": acc.item(), | 1003 | "train/cur_acc": acc.item(), |
1005 | "lr": lr_scheduler.get_last_lr()[0], | 1004 | "lr": lr_scheduler.get_last_lr()[0], |
1006 | } | 1005 | } |
1006 | if args.use_ema: | ||
1007 | logs["ema_decay"] = ema_embeddings.decay | ||
1007 | 1008 | ||
1008 | accelerator.log(logs, step=global_step) | 1009 | accelerator.log(logs, step=global_step) |
1009 | 1010 | ||