diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 30 |
1 file changed, 5 insertions(+), 25 deletions(-)
diff --git a/train_ti.py b/train_ti.py index 0ce0056..26ac384 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -1,6 +1,7 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import datetime | 2 | import datetime |
3 | import logging | 3 | import logging |
4 | import itertools | ||
4 | from functools import partial | 5 | from functools import partial |
5 | from pathlib import Path | 6 | from pathlib import Path |
6 | import math | 7 | import math |
@@ -307,26 +308,6 @@ def parse_args(): | |||
307 | help="Minimum learning rate in the lr scheduler." | 308 | help="Minimum learning rate in the lr scheduler." |
308 | ) | 309 | ) |
309 | parser.add_argument( | 310 | parser.add_argument( |
310 | "--use_ema", | ||
311 | action="store_true", | ||
312 | help="Whether to use EMA model." | ||
313 | ) | ||
314 | parser.add_argument( | ||
315 | "--ema_inv_gamma", | ||
316 | type=float, | ||
317 | default=1.0 | ||
318 | ) | ||
319 | parser.add_argument( | ||
320 | "--ema_power", | ||
321 | type=float, | ||
322 | default=4/5 | ||
323 | ) | ||
324 | parser.add_argument( | ||
325 | "--ema_max_decay", | ||
326 | type=float, | ||
327 | default=0.9999 | ||
328 | ) | ||
329 | parser.add_argument( | ||
330 | "--optimizer", | 311 | "--optimizer", |
331 | type=str, | 312 | type=str, |
332 | default="dadan", | 313 | default="dadan", |
@@ -715,10 +696,6 @@ def main(): | |||
715 | sample_scheduler=sample_scheduler, | 696 | sample_scheduler=sample_scheduler, |
716 | checkpoint_output_dir=checkpoint_output_dir, | 697 | checkpoint_output_dir=checkpoint_output_dir, |
717 | gradient_checkpointing=args.gradient_checkpointing, | 698 | gradient_checkpointing=args.gradient_checkpointing, |
718 | use_ema=args.use_ema, | ||
719 | ema_inv_gamma=args.ema_inv_gamma, | ||
720 | ema_power=args.ema_power, | ||
721 | ema_max_decay=args.ema_max_decay, | ||
722 | sample_batch_size=args.sample_batch_size, | 699 | sample_batch_size=args.sample_batch_size, |
723 | sample_num_batches=args.sample_batches, | 700 | sample_num_batches=args.sample_batches, |
724 | sample_num_steps=args.sample_steps, | 701 | sample_num_steps=args.sample_steps, |
@@ -780,7 +757,10 @@ def main(): | |||
780 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 757 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
781 | 758 | ||
782 | optimizer = create_optimizer( | 759 | optimizer = create_optimizer( |
783 | text_encoder.text_model.embeddings.overlay.parameters(), | 760 | itertools.chain( |
761 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
762 | text_encoder.text_model.embeddings.overlay.parameters(), | ||
763 | ), | ||
784 | lr=args.learning_rate, | 764 | lr=args.learning_rate, |
785 | ) | 765 | ) |
786 | 766 | ||