Diffstat (limited to 'train_ti.py')
-rw-r--r--   train_ti.py   52
1 file changed, 46 insertions(+), 6 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 26ac384..5482326 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -308,6 +307,26 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=4/5
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--optimizer",
         type=str,
         default="dadan",
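The three EMA flags added above parameterize a warm-up decay schedule. The script's EMA code itself is outside this diff, so the following is only a minimal sketch of the conventional schedule these names and defaults suggest (the one used by diffusers-style EMAModel implementations), not necessarily train_ti.py's exact code:

    def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 4 / 5,
                  max_decay: float = 0.9999) -> float:
        # Warm-up schedule: decay is small early on (the EMA copy tracks the
        # live weights closely) and ramps toward max_decay as training runs.
        value = 1 - (1 + step / inv_gamma) ** -power
        return max(0.0, min(value, max_decay))

Under this formula, the defaults above give a decay of roughly 0.975 at step 100 and 0.996 at step 1000, capped at 0.9999.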
@@ -334,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -432,6 +451,23 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e2,
+        type=float,
+        help="Embedding decay factor."
+    )
+    parser.add_argument(
         "--noise_timesteps",
         type=int,
         default=1000,
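The decay step that consumes these options lives in the trainer rather than in this hunk, so the helper below is a hypothetical sketch of what the names and defaults suggest: after each optimizer step, pull the L2 norm of each trained token embedding toward --emb_decay_target, with a step size scaled by --emb_decay and the current learning rate.

    import torch

    @torch.no_grad()
    def apply_emb_decay(weight: torch.Tensor, target: float = 0.4,
                        decay: float = 1e2, lr: float = 1e-4) -> None:
        # Hypothetical helper: rescale each embedding row so its norm moves
        # a fraction lambda_ of the way toward `target` per call.
        norm = weight.norm(dim=-1, keepdim=True)
        lambda_ = min(1.0, decay * lr)
        weight.mul_(1 - lambda_ * (norm - target) / norm.clamp(min=1e-8))

One common motivation for a regularizer of this shape is to keep learned textual-inversion embeddings from drifting to norms far outside the range of ordinary token embeddings.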
@@ -696,6 +732,13 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
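None of the receiving trainer code appears in this diff, so the wiring below is only a plausible sketch of where these kwargs would take effect: an after-optimizer hook that applies the embedding decay and then updates an EMA copy of the embedding weights. ema_decay and apply_emb_decay are the hypothetical helpers sketched above; embedding, ema_params, and the bare flag names stand in for trainer state.

    def on_after_optimize(step: int, lr: float) -> None:
        # Hypothetical hook, called once per optimizer step.
        if use_emb_decay:
            apply_emb_decay(embedding.weight, emb_decay_target, emb_decay, lr)
        if use_ema:
            d = ema_decay(step, ema_inv_gamma, ema_power, ema_max_decay)
            with torch.no_grad():
                for ema_p, p in zip(ema_params, embedding.parameters()):
                    # ema = d * ema + (1 - d) * live
                    ema_p.lerp_(p, 1.0 - d)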
@@ -757,10 +800,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        itertools.chain(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            text_encoder.text_model.embeddings.overlay.parameters(),
-        ),
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
     )
 