path: root/train_ti.py
author    Volpeon <git@volpeon.ink>  2023-01-16 07:27:45 +0100
committer Volpeon <git@volpeon.ink>  2023-01-16 07:27:45 +0100
commit    3c6ccadd3c12c54a1fa2280bce505a2dd511958a (patch)
tree      019b9ac09acc85196ef1d09e2d968ba917ac8993 /train_ti.py
parent    Added Dreambooth strategy (diff)
Implemented extended Dreambooth training
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  62
1 file changed, 22 insertions, 40 deletions
diff --git a/train_ti.py b/train_ti.py
index 2497519..48a2333 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,7 @@ from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from data.csv import VlpnDataModule, VlpnDataItem
+from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -446,15 +446,15 @@ def parse_args():
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)]
-
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
     if len(args.initializer_tokens) == 0:
         raise ValueError("You must specify --initializer_tokens")
 
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
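The reordering above lets the placeholder tokens be derived from the initializer tokens instead of the other way around, and the new default calls range(len(args.initializer_tokens)) where the old code passed the list itself to range(). A minimal sketch of the resulting defaulting behaviour, with hypothetical argument values:

# e.g. --initializer_tokens cat dog given, --placeholder_tokens omitted
initializer_tokens = ["cat", "dog"]
placeholder_tokens = []

if len(placeholder_tokens) == 0:
    placeholder_tokens = [f"<*{i}>" for i in range(len(initializer_tokens))]

assert placeholder_tokens == ["<*0>", "<*1>"]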
@@ -544,9 +544,6 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
 
-    if args.find_lr:
-        args.learning_rate = 1e-5
-
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -563,19 +560,6 @@
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: VlpnDataItem):
-        cond1 = any(
-            keyword in part
-            for keyword in args.placeholder_tokens
-            for part in item.prompt
-        )
-        cond3 = args.collection is None or args.collection in item.collection
-        cond4 = args.exclude_collections is None or not any(
-            collection in item.collection
-            for collection in args.exclude_collections
-        )
-        return cond1 and cond3 and cond4
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -593,7 +577,7 @@ def main():
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=keyword_filter,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
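keyword_filter is no longer an inline closure over args; it is imported from data.csv (first hunk) and bound to the parsed arguments with functools.partial at the call site above. Its definition is not part of this diff, but a sketch of what it presumably looks like, reconstructed from the removed closure and the argument order of the partial(...) call (parameter names are assumptions):

def keyword_filter(placeholder_tokens, collection, exclude_collections, item):
    # keep items whose prompt mentions one of the placeholder tokens ...
    cond1 = any(
        keyword in part
        for keyword in placeholder_tokens
        for part in item.prompt
    )
    # ... that belong to the requested collection, if one was given ...
    cond3 = collection is None or collection in item.collection
    # ... and that are not part of an excluded collection
    cond4 = exclude_collections is None or not any(
        c in item.collection
        for c in exclude_collections
    )
    return cond1 and cond3 and cond4

Bound as in the hunk above, partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections) yields the one-argument callable that VlpnDataModule applies to each data item.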
@@ -622,8 +606,6 @@ def main():
         text_encoder=text_encoder,
         vae=vae,
         noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
         dtype=weight_dtype,
         seed=args.seed,
         callbacks_fn=textual_inversion_strategy
@@ -638,25 +620,25 @@ def main():
         amsgrad=args.adam_amsgrad,
     )
 
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            min_lr=args.lr_min_lr,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-        )
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
 
     trainer(
+        project="textual_inversion",
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,