From c96073646bbb638d7d78fdd7d9fdeed08d1454b5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 16:30:36 +0200
Subject: Experimental: TI via LoRA

---
 train_ti.py | 24 ++----------------------
 1 file changed, 2 insertions(+), 22 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 5482326..0ce0056 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -450,23 +450,6 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
-    parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
-    )
-    parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
-    )
-    parser.add_argument(
-        "--emb_decay",
-        default=1e2,
-        type=float,
-        help="Embedding decay factor."
-    )
     parser.add_argument(
         "--noise_timesteps",
         type=int,
@@ -732,9 +715,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
@@ -800,7 +780,7 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        text_encoder.text_model.embeddings.overlay.parameters(),
         lr=args.learning_rate,
     )
 
--
cgit v1.2.3-54-g00ecf
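
Note on the change: the optimizer now trains text_encoder.text_model.embeddings.overlay.parameters() instead of a temporary token embedding, and the explicit --use_emb_decay/--emb_decay* mechanism is dropped in favor of plain AdamW weight decay (--adam_weight_decay now defaults to 1e-2). The patch does not show how "overlay" is defined; the sketch below is one plausible shape for a LoRA-style embedding overlay. The class name, rank, and initialization are assumptions for illustration, not the repository's actual code.

    import torch
    import torch.nn as nn

    class LoRAEmbeddingOverlay(nn.Module):
        """Hypothetical sketch: frozen base embedding plus a trainable
        low-rank (LoRA-style) delta, so only the delta is optimized."""

        def __init__(self, base: nn.Embedding, rank: int = 4):
            super().__init__()
            self.base = base
            self.base.weight.requires_grad_(False)  # base embedding stays frozen
            num_embeddings, dim = base.weight.shape
            # Low-rank factors: delta = lora_a @ lora_b starts at zero because
            # lora_a is zero-initialized, so training begins from the base weights.
            self.lora_a = nn.Parameter(torch.zeros(num_embeddings, rank))
            self.lora_b = nn.Parameter(torch.randn(rank, dim) * 0.01)

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            delta = self.lora_a @ self.lora_b  # (num_embeddings, dim)
            return self.base(input_ids) + delta[input_ids]

    # Usage mirroring the patched optimizer setup: only the overlay's trainable
    # parameters are optimized, and regularization comes from AdamW's weight
    # decay (bumped to 1e-2 in this patch) rather than the removed
    # --emb_decay terms.
    base = nn.Embedding(49408, 768)  # CLIP-sized vocab/dim, assumed
    overlay = LoRAEmbeddingOverlay(base, rank=4)
    optimizer = torch.optim.AdamW(
        (p for p in overlay.parameters() if p.requires_grad),
        lr=1e-4,
        weight_decay=1e-2,
    )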