author    Volpeon <git@volpeon.ink>  2023-03-28 16:24:22 +0200
committer Volpeon <git@volpeon.ink>  2023-03-28 16:24:22 +0200
commit    179a45253a5b3712f32bd127f693a6bb810a9c17 (patch)
tree      ac9f1152d858089742e4f9ce79e0870e0f2b9a2d /train_ti.py
parent    Fix TI (diff)
download  textual-inversion-diff-179a45253a5b3712f32bd127f693a6bb810a9c17.tar.gz
          textual-inversion-diff-179a45253a5b3712f32bd127f693a6bb810a9c17.tar.bz2
          textual-inversion-diff-179a45253a5b3712f32bd127f693a6bb810a9c17.zip
Support num_train_steps arg again
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  30
1 file changed, 21 insertions(+), 9 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index e4fd464..7bcc72f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -207,7 +208,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
+    )
+    parser.add_argument(
+        "--num_train_steps",
+        type=int,
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -513,13 +519,13 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
-    args.alias_tokens += [
-        item
-        for pair in zip(args.placeholder_tokens, args.initializer_tokens)
-        for item in pair
-    ]
-
     if args.sequential:
+        args.alias_tokens += [
+            item
+            for pair in zip(args.placeholder_tokens, args.initializer_tokens)
+            for item in pair
+        ]
+
     if isinstance(args.train_data_template, str):
         args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -607,6 +613,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
@@ -682,7 +689,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
-        num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
@@ -752,6 +758,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
@@ -769,7 +780,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e3,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -779,6 +790,7 @@ def main():
         val_dataloader=datamodule.val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        num_train_epochs=num_train_epochs,
         # --
         sample_output_dir=sample_output_dir,
         placeholder_tokens=placeholder_tokens,
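
For reference, a minimal sketch of the fallback this commit introduces: `--num_train_epochs` now defaults to `None`, and when it is not given, an epoch count is derived from `--num_train_steps` and the size of the training set. The function and variable names below are illustrative only and are not part of train_ti.py; the expression simply mirrors the one in the diff.

```python
import math


def resolve_num_train_epochs(num_train_epochs, num_train_steps, train_dataset_len):
    """Illustrative helper (not from train_ti.py): resolve the epoch count.

    If --num_train_epochs was passed, use it as-is; otherwise derive an epoch
    count from --num_train_steps, mirroring the expression added in this commit:
    ceil(len(train_dataset) / num_train_steps).
    """
    if num_train_epochs is not None:
        return num_train_epochs
    return math.ceil(train_dataset_len / num_train_steps)


# Example: with the new default of num_train_steps=2000 and a 100-image
# training set, the fallback resolves to 1 epoch.
print(resolve_num_train_epochs(None, 2000, 100))  # -> 1
```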