-rw-r--r--  data/csv.py             |  19
-rw-r--r--  train_dreambooth.py     | 484
-rw-r--r--  train_ti.py             |  62
-rw-r--r--  training/functional.py  |   7
4 files changed, 200 insertions(+), 372 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 2a8115b..2b1e202 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -130,6 +130,25 @@ class VlpnDataItem(NamedTuple):
     collection: list[str]


+def keyword_filter(
+    placeholder_tokens: Optional[list[str]],
+    collection: Optional[list[str]],
+    exclude_collections: Optional[list[str]],
+    item: VlpnDataItem
+):
+    cond1 = placeholder_tokens is None or any(
+        keyword in part
+        for keyword in placeholder_tokens
+        for part in item.prompt
+    )
+    cond2 = collection is None or collection in item.collection
+    cond3 = exclude_collections is None or not any(
+        collection in item.collection
+        for collection in exclude_collections
+    )
+    return cond1 and cond2 and cond3
+
+
 class VlpnDataModule():
     def __init__(
         self,
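keyword_filter is now a module-level helper shared by train_ti.py and train_dreambooth.py: the three configuration arguments are bound up front, leaving a single-argument predicate for VlpnDataModule. A minimal sketch of the intended call shape (the token and collection values below are illustrative, not from the commit):

    from functools import partial

    from data.csv import keyword_filter

    # Binding the first three parameters yields a predicate over a single
    # VlpnDataItem, which is what VlpnDataModule's `filter` argument expects.
    filter_fn = partial(
        keyword_filter,
        ["<token1>"],   # placeholder_tokens: prompt must mention one of these
        None,           # collection: None skips the collection check
        ["lowres"],     # exclude_collections: drop items in these collections
    )
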
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 71bad7e..944256c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1,10 +1,8 @@
 import argparse
-import itertools
 import datetime
 import logging
 from pathlib import Path
 from functools import partial
-from contextlib import contextmanager, nullcontext

 import torch
 import torch.utils.checkpoint
@@ -12,18 +10,15 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, UNet2DConditionModel
-import matplotlib.pyplot as plt
-from transformers import CLIPTextModel
 from slugify import slugify

 from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import VlpnDataModule, VlpnDataItem
+from data.csv import VlpnDataModule, keyword_filter
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.strategy.ti import textual_inversion_strategy
+from training.strategy.dreambooth import dreambooth_strategy
 from training.optimization import get_scheduler
-from training.lr import LRFinder
-from training.util import CheckpointerBase, EMAModel, save_args, generate_class_images, add_placeholder_tokens, get_models
-from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import save_args

 logger = get_logger(__name__)

@@ -73,7 +68,7 @@ def parse_args():
73 help="A token to use as a placeholder for the concept.", 68 help="A token to use as a placeholder for the concept.",
74 ) 69 )
75 parser.add_argument( 70 parser.add_argument(
76 "--initializer_token", 71 "--initializer_tokens",
77 type=str, 72 type=str,
78 nargs='*', 73 nargs='*',
79 default=[], 74 default=[],
@@ -151,7 +146,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=1,
+        default=0,
         help="How many class images to generate."
     )
     parser.add_argument(
@@ -437,23 +432,23 @@ def parse_args():
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]

-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_token)]
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)

-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token] * len(args.placeholder_tokens)
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")

-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]

-    if len(args.placeholder_tokens) != len(args.initializer_token):
-        raise ValueError("--placeholder_tokens and --initializer_token must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

     if args.num_vectors is None:
         args.num_vectors = 1

     if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)

     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
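With the reordered checks, a lone --initializer_tokens list can drive everything: placeholder names are now derived from the initializers, never the other way around. A small worked example of the new fallback (values are illustrative):

    # Given --initializer_tokens a b c and no --placeholder_tokens:
    initializer_tokens = ["a", "b", "c"]
    placeholder_tokens = [f"<*{i}>" for i in range(len(initializer_tokens))]
    assert placeholder_tokens == ["<*0>", "<*1>", "<*2>"]
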
@@ -470,102 +465,9 @@ def parse_args():
     return args


-class Checkpointer(CheckpointerBase):
-    def __init__(
-        self,
-        weight_dtype: torch.dtype,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        ema_unet: EMAModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
-        scheduler,
-        *args,
-        **kwargs
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.ema_unet = ema_unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.scheduler = scheduler
-
-    @torch.no_grad()
-    def save_model(self):
-        print("Saving model...")
-
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
-
-        with ema_context:
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            )
-            pipeline.save_pretrained(self.output_dir.joinpath("model"))
-
-        del unet
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-    @torch.no_grad()
-    def save_samples(self, step):
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
-
-        with ema_context:
-            orig_unet_dtype = unet.dtype
-            orig_text_encoder_dtype = text_encoder.dtype
-
-            unet.to(dtype=self.weight_dtype)
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step)
-
-            unet.to(dtype=orig_unet_dtype)
-            text_encoder.to(dtype=orig_text_encoder_dtype)
-
-        del unet
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-
 def main():
     args = parse_args()

-    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
-
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -621,41 +523,12 @@ def main():
     placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
     print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")

-    if args.use_ema:
-        ema_unet = EMAModel(
-            unet.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-    else:
-        ema_unet = None
-
-    vae.requires_grad_(False)
-
-    if args.train_text_encoder:
-        print(f"Training entire text encoder.")
-
-        embeddings.persist()
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
-    else:
-        print(f"Training added text embeddings")
-
-        text_encoder.text_model.encoder.requires_grad_(False)
-        text_encoder.text_model.final_layer_norm.requires_grad_(False)
-        text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-        text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )

-    if args.find_lr:
-        args.learning_rate = 1e-6
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -666,41 +539,30 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW

-    if args.train_text_encoder:
-        text_encoder_params_to_optimize = text_encoder.parameters()
-    else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-
-    # Initialize the optimizer
-    optimizer = optimizer_class(
-        [
-            {
-                'params': unet.parameters(),
-            },
-            {
-                'params': text_encoder_params_to_optimize,
-            }
-        ],
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    def keyword_filter(item: VlpnDataItem):
-        cond3 = args.collection is None or args.collection in item.collection
-        cond4 = args.exclude_collections is None or not any(
-            collection in item.collection
-            for collection in args.exclude_collections
-        )
-        return cond3 and cond4
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        dtype=weight_dtype,
+        seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    # Initial TI
+
+    print("Phase 1: Textual Inversion")
+
+    cur_dir = output_dir.joinpath("1-ti")
+    cur_dir.mkdir(parents=True, exist_ok=True)

     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
@@ -709,182 +571,146 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        num_buckets=args.num_buckets,
-        progressive_buckets=args.progressive_buckets,
-        bucket_step_size=args.bucket_step_size,
-        bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=keyword_filter,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()

-    train_dataloader = datamodule.train_dataloader
-    val_dataloader = datamodule.val_dataloader
-
-    if args.num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            sample_scheduler,
-            datamodule.data_train,
-            args.sample_batch_size,
-            args.sample_image_size,
-            args.sample_steps
-        )
-
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            min_lr=args.lr_min_lr,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps
-        )
+    optimizer = optimizer_class(
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=2e-1,
+        weight_decay=0.0,
+    )

-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    lr_scheduler = get_scheduler(
+        "linear",
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        train_epochs=30,
+        warmup_epochs=10,
+    )
+
+    trainer(
+        project="textual_inversion",
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=30,
+        sample_frequency=5,
+        checkpoint_frequency=9999999,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        output_dir=cur_dir,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        learning_rate=2e-1,
+        gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=True,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )

-    vae.to(accelerator.device, dtype=weight_dtype)
+    # Dreambooth

-    if args.use_ema:
-        ema_unet.to(accelerator.device)
+    print("Phase 2: Dreambooth")

-    @contextmanager
-    def on_train(epoch: int):
-        try:
-            tokenizer.train()
+    cur_dir = output_dir.joinpath("2db")
+    cur_dir.mkdir(parents=True, exist_ok=True)

-            if epoch < args.train_text_encoder_epochs:
-                text_encoder.train()
-            elif epoch == args.train_text_encoder_epochs:
-                text_encoder.requires_grad_(False)
+    args.seed = (args.seed + 28635) >> 32

-            yield
-        finally:
-            pass
+    datamodule = VlpnDataModule(
+        data_file=args.train_data_file,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=args.class_image_dir,
+        num_class_images=args.num_class_images,
+        size=args.resolution,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
+        bucket_step_size=args.bucket_step_size,
+        bucket_max_pixels=args.bucket_max_pixels,
+        dropout=args.tag_dropout,
+        shuffle=not args.no_tag_shuffle,
+        template_key=args.train_data_template,
+        valid_set_size=args.valid_set_size,
+        valid_set_repeat=args.valid_set_repeat,
+        seed=args.seed,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        dtype=weight_dtype
+    )
+    datamodule.setup()

-    @contextmanager
-    def on_eval():
-        try:
-            tokenizer.eval()
-            text_encoder.eval()
-
-            ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
-
-    def on_before_optimize(epoch: int):
-        if accelerator.sync_gradients:
-            params_to_clip = [unet.parameters()]
-            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
-                params_to_clip.append(text_encoder.parameters())
-            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
-
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        if not args.train_text_encoder:
-            text_encoder.text_model.embeddings.normalize(
-                args.decay_target,
-                min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
-            )
-
-    def on_log():
-        if args.use_ema:
-            return {"ema_decay": ema_unet.decay}
-        return {}
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        ema_unet=ema_unet,
+    optimizer = optimizer_class(
+        [
+            {
+                'params': unet.parameters(),
+            },
+            {
+                'params': text_encoder.parameters(),
+            }
+        ],
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
+
+    trainer(
+        project="dreambooth",
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
         tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        scheduler=sample_scheduler,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        output_dir=output_dir,
-        sample_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
+        sample_scheduler=sample_scheduler,
+        output_dir=cur_dir,
+        gradient_checkpointing=args.gradient_checkpointing,
+        train_text_encoder_epochs=args.train_text_encoder_epochs,
+        max_grad_norm=args.max_grad_norm,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=config)
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            model=unet,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            on_train=on_train,
-            on_eval=on_eval,
-            on_before_optimize=on_before_optimize,
-            on_after_optimize=on_after_optimize,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e2)
-
-        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            model=unet,
-            checkpointer=checkpointer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=0,
-            num_epochs=args.num_train_epochs,
-            on_log=on_log,
-            on_train=on_train,
-            on_after_optimize=on_after_optimize,
-            on_eval=on_eval
-        )
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
+    )


 if __name__ == "__main__":
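train_dreambooth.py now drives both phases through one partially-applied train() call: what the phases share is bound once, and each phase supplies its own project name, schedule, dataloaders, and output directory. A reduced, self-contained stand-in showing just that pattern (the real train() in training/functional.py also takes models, an accelerator, and a callbacks_fn strategy):

    from functools import partial

    def train(*, seed: int, project: str, num_train_epochs: int) -> None:
        # stand-in for training.functional.train, heavily reduced
        print(f"[{project}] {num_train_epochs} epochs (seed={seed})")

    # bind what both phases share once...
    trainer = partial(train, seed=0)

    # ...then invoke it once per phase with phase-specific settings
    trainer(project="textual_inversion", num_train_epochs=30)
    trainer(project="dreambooth", num_train_epochs=100)
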
diff --git a/train_ti.py b/train_ti.py
index 2497519..48a2333 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,7 @@ from accelerate.utils import LoggerType, set_seed
 from slugify import slugify

 from util import load_config, load_embeddings_from_dir
-from data.csv import VlpnDataModule, VlpnDataItem
+from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -446,15 +446,15 @@ def parse_args():
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]

-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)]
-
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)

     if len(args.initializer_tokens) == 0:
         raise ValueError("You must specify --initializer_tokens")

+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

@@ -544,9 +544,6 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )

-    if args.find_lr:
-        args.learning_rate = 1e-5
-
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -563,19 +560,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    def keyword_filter(item: VlpnDataItem):
-        cond1 = any(
-            keyword in part
-            for keyword in args.placeholder_tokens
-            for part in item.prompt
-        )
-        cond3 = args.collection is None or args.collection in item.collection
-        cond4 = args.exclude_collections is None or not any(
-            collection in item.collection
-            for collection in args.exclude_collections
-        )
-        return cond1 and cond3 and cond4
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -593,7 +577,7 @@ def main():
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=keyword_filter,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
@@ -622,8 +606,6 @@ def main():
         text_encoder=text_encoder,
         vae=vae,
         noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
         dtype=weight_dtype,
         seed=args.seed,
         callbacks_fn=textual_inversion_strategy
@@ -638,25 +620,25 @@ def main():
         amsgrad=args.adam_amsgrad,
     )

-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            min_lr=args.lr_min_lr,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-        )
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )

     trainer(
+        project="textual_inversion",
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
diff --git a/training/functional.py b/training/functional.py
index 5984ffb..f5c111e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -494,10 +494,11 @@ def train(
     text_encoder: CLIPTextModel,
     vae: AutoencoderKL,
     noise_scheduler: DDPMScheduler,
-    train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
     dtype: torch.dtype,
     seed: int,
+    project: str,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
@@ -544,7 +545,7 @@ def train(
     )

     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
+        accelerator.init_trackers(project)

     train_loop(
         accelerator=accelerator,
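The new project parameter on train() is what lets the two phases above show up as separately named tracker runs. A minimal sketch of the behavior change (stubs only, not the real Accelerate API surface):

    # Stand-ins only: the real code calls Accelerator.init_trackers(project).
    def init_trackers(name: str) -> None:
        print(f"tracking run: {name}")

    def train(*, project: str) -> None:  # heavily reduced signature
        init_trackers(project)  # previously hard-coded to "textual_inversion"

    train(project="textual_inversion")
    train(project="dreambooth")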