From 208e48134e324e934ad964bdc61880cc923f4c0d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 22:13:55 +0200
Subject: Revert

---
 train_ti.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 46 insertions(+), 6 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 26ac384..5482326 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -307,6 +306,26 @@ def parse_args():
         default=0.04,
         help="Minimum learning rate in the lr scheduler."
     )
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=4/5
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -334,7 +353,7 @@
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -431,6 +450,23 @@
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e2,
+        type=float,
+        help="Embedding decay factor."
+    )
     parser.add_argument(
         "--noise_timesteps",
         type=int,
@@ -696,6 +732,13 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -757,10 +800,7 @@
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))

     optimizer = create_optimizer(
-        itertools.chain(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            text_encoder.text_model.embeddings.overlay.parameters(),
-        ),
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
     )

--
cgit v1.2.3-54-g00ecf
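
Note on the EMA flags this revert reinstates: --ema_inv_gamma, --ema_power, and --ema_max_decay conventionally parameterize a warm-up schedule for the EMA decay rate, in the style of diffusers' EMAModel. The sketch below is a minimal illustration under that assumption, not this repository's actual trainer code; the function name ema_decay is hypothetical.

# Hypothetical sketch of a diffusers-style EMA decay warm-up schedule,
# using the defaults added back by this patch (assumption, not the
# repository's implementation).
def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 4 / 5,
              max_decay: float = 0.9999) -> float:
    # Decay ramps up from 0 toward max_decay as training progresses.
    decay = 1.0 - (1.0 + step / inv_gamma) ** -power
    return min(max_decay, max(0.0, decay))

for step in (1, 10, 100, 1000, 10000):
    print(step, round(ema_decay(step), 6))

Under this schedule and these defaults, the decay rises quickly (roughly 0.85 after 10 steps and 0.97 after 100) before being capped at max_decay = 0.9999.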