path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2023-04-01 17:33:00 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-01 17:33:00 +0200
commit     86e908656bcd7585ec45cd930176800f759f146a (patch)
tree       1169e9b1728e4c6fc8b70e46a37080ae0794ada8 /train_ti.py
parent     Experimental: TI via LoRA (diff)
download   textual-inversion-diff-86e908656bcd7585ec45cd930176800f759f146a.tar.gz
           textual-inversion-diff-86e908656bcd7585ec45cd930176800f759f146a.tar.bz2
           textual-inversion-diff-86e908656bcd7585ec45cd930176800f759f146a.zip
Combined TI with embedding and LoRA
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  30
1 file changed, 5 insertions, 25 deletions
diff --git a/train_ti.py b/train_ti.py
index 0ce0056..26ac384 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -307,26 +308,6 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=4/5
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
         "--optimizer",
         type=str,
         default="dadan",
@@ -715,10 +696,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -780,7 +757,10 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))

     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.overlay.parameters(),
+        itertools.chain(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
+        ),
         lr=args.learning_rate,
     )

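
Taken together, the hunks above drop the unused EMA options and make the optimizer update the textual-inversion token embedding and the LoRA overlay jointly. Below is a minimal sketch of that pattern in plain PyTorch; the nn.Embedding/nn.Linear stand-ins and the AdamW choice are illustrative assumptions only, whereas in train_ti.py the real modules are text_encoder.text_model.embeddings.temp_token_embedding and text_encoder.text_model.embeddings.overlay, wrapped by the repository's create_optimizer helper.

import itertools

import torch
from torch import nn

# Stand-ins for illustration only. In train_ti.py the real modules are
# text_encoder.text_model.embeddings.temp_token_embedding (the TI token
# vectors) and text_encoder.text_model.embeddings.overlay (the LoRA overlay).
temp_token_embedding = nn.Embedding(4, 768)
overlay = nn.Linear(768, 768, bias=False)

# itertools.chain concatenates both parameter iterators, so a single optimizer
# steps over the TI embedding and the LoRA overlay together, mirroring the
# create_optimizer(itertools.chain(...)) call added in this commit.
# AdamW is a placeholder; train_ti.py selects the optimizer via --optimizer.
optimizer = torch.optim.AdamW(
    itertools.chain(
        temp_token_embedding.parameters(),
        overlay.parameters(),
    ),
    lr=1e-4,  # placeholder; the script passes args.learning_rate
)

# One illustrative update step on dummy data.
loss = (overlay(temp_token_embedding(torch.tensor([0, 1]))) ** 2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()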