From 86e908656bcd7585ec45cd930176800f759f146a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Apr 2023 17:33:00 +0200
Subject: Combined TI with embedding and LoRA

---
 train_ti.py | 30 +++++-------------------------
 1 file changed, 5 insertions(+), 25 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 0ce0056..26ac384 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -306,26 +307,6 @@ def parse_args():
         default=0.04,
         help="Minimum learning rate in the lr scheduler."
     )
-    parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=4/5
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -715,10 +696,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -780,7 +757,10 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.overlay.parameters(),
+        itertools.chain(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
-- 
cgit v1.2.3-54-g00ecf
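
The core of this change is routing two parameter iterables into a single optimizer via itertools.chain, so the textual-inversion token embedding and the LoRA-style overlay are trained together. Below is a minimal, self-contained sketch of that pattern in isolation; the module shapes, the AdamW choice, and the dummy loss are illustrative assumptions, not code from train_ti.py.

    # Sketch of the itertools.chain pattern used in the patch.
    # Stand-in modules; shapes and optimizer are assumptions, not from train_ti.py.
    import itertools

    import torch
    import torch.nn as nn

    # Stand-ins for temp_token_embedding and overlay in the patched train_ti.py.
    temp_token_embedding = nn.Embedding(8, 32)
    overlay = nn.Linear(32, 32)

    # itertools.chain lazily concatenates both parameter iterators, so one
    # optimizer updates the TI embedding and the overlay weights together.
    params = itertools.chain(
        temp_token_embedding.parameters(),
        overlay.parameters(),
    )
    optimizer = torch.optim.AdamW(params, lr=1e-4)

    # One dummy step to show both parameter groups receive gradients.
    loss = overlay(temp_token_embedding(torch.tensor([0, 1]))).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Chaining the iterators keeps both parameter sets in a single optimizer group with one learning rate, which matches how the patch passes the combined iterable to create_optimizer with lr=args.learning_rate.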