Diffstat:
 data/csv.py            |  19
 train_dreambooth.py    | 484
 train_ti.py            |  62
 training/functional.py |   7
 4 files changed, 200 insertions(+), 372 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 2a8115b..2b1e202 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -130,6 +130,25 @@ class VlpnDataItem(NamedTuple):
     collection: list[str]
 
 
+def keyword_filter(
+    placeholder_tokens: Optional[list[str]],
+    collection: Optional[list[str]],
+    exclude_collections: Optional[list[str]],
+    item: VlpnDataItem
+):
+    cond1 = placeholder_tokens is None or any(
+        keyword in part
+        for keyword in placeholder_tokens
+        for part in item.prompt
+    )
+    cond2 = collection is None or collection in item.collection
+    cond3 = exclude_collections is None or not any(
+        collection in item.collection
+        for collection in exclude_collections
+    )
+    return cond1 and cond2 and cond3
+
+
 class VlpnDataModule():
     def __init__(
         self,
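
keyword_filter is now a module-level function that takes its configuration as leading parameters, so callers can specialize it with functools.partial and hand VlpnDataModule a one-argument predicate. A minimal sketch of that call pattern (the token and collection values are hypothetical, and FakeItem is a stand-in exposing only the two fields the filter reads):

    from functools import partial
    from typing import NamedTuple

    from data.csv import keyword_filter

    class FakeItem(NamedTuple):
        # Stand-in for VlpnDataItem: only prompt and collection are read.
        prompt: list[str]
        collection: list[str]

    # Bind the static configuration once; the result is a predicate
    # over single items, which is what the datamodule expects.
    filter_fn = partial(
        keyword_filter,
        ["<emb>"],    # placeholder_tokens (hypothetical value)
        None,         # collection: no include filter
        ["rejects"],  # exclude_collections (hypothetical value)
    )

    print(filter_fn(FakeItem(["a photo of <emb>"], ["main"])))     # True
    print(filter_fn(FakeItem(["a photo of <emb>"], ["rejects"])))  # False

The same partial(keyword_filter, ...) expression is what both training scripts below now pass as the datamodule's filter argument.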
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 71bad7e..944256c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1,10 +1,8 @@
 import argparse
-import itertools
 import datetime
 import logging
 from pathlib import Path
 from functools import partial
-from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -12,18 +10,15 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, UNet2DConditionModel
-import matplotlib.pyplot as plt
-from transformers import CLIPTextModel
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import VlpnDataModule, VlpnDataItem
+from data.csv import VlpnDataModule, keyword_filter
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.strategy.ti import textual_inversion_strategy
+from training.strategy.dreambooth import dreambooth_strategy
 from training.optimization import get_scheduler
-from training.lr import LRFinder
-from training.util import CheckpointerBase, EMAModel, save_args, generate_class_images, add_placeholder_tokens, get_models
-from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import save_args
 
 logger = get_logger(__name__)
 
@@ -73,7 +68,7 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
-        "--initializer_token",
+        "--initializer_tokens",
         type=str,
         nargs='*',
         default=[],
@@ -151,7 +146,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=1,
+        default=0,
         help="How many class images to generate."
     )
     parser.add_argument(
@@ -437,23 +432,23 @@ def parse_args():
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_token)]
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token] * len(args.placeholder_tokens)
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")
 
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
-    if len(args.placeholder_tokens) != len(args.initializer_token):
-        raise ValueError("--placeholder_tokens and --initializer_token must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
     if args.num_vectors is None:
         args.num_vectors = 1
 
     if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
 
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
@@ -470,102 +465,9 @@ def parse_args():
     return args
 
 
-class Checkpointer(CheckpointerBase):
-    def __init__(
-        self,
-        weight_dtype: torch.dtype,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        ema_unet: EMAModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
-        scheduler,
-        *args,
-        **kwargs
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.ema_unet = ema_unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.scheduler = scheduler
-
-    @torch.no_grad()
-    def save_model(self):
-        print("Saving model...")
-
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
-
-        with ema_context:
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            )
-            pipeline.save_pretrained(self.output_dir.joinpath("model"))
-
-        del unet
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-    @torch.no_grad()
-    def save_samples(self, step):
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
-
-        with ema_context:
-            orig_unet_dtype = unet.dtype
-            orig_text_encoder_dtype = text_encoder.dtype
-
-            unet.to(dtype=self.weight_dtype)
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step)
-
-            unet.to(dtype=orig_unet_dtype)
-            text_encoder.to(dtype=orig_text_encoder_dtype)
-
-        del unet
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-
 def main():
     args = parse_args()
 
-    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
-
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -621,41 +523,12 @@ def main():
     placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
     print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
-    if args.use_ema:
-        ema_unet = EMAModel(
-            unet.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-    else:
-        ema_unet = None
-
-    vae.requires_grad_(False)
-
-    if args.train_text_encoder:
-        print(f"Training entire text encoder.")
-
-        embeddings.persist()
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
-    else:
-        print(f"Training added text embeddings")
-
-        text_encoder.text_model.encoder.requires_grad_(False)
-        text_encoder.text_model.final_layer_norm.requires_grad_(False)
-        text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-        text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
-    if args.find_lr:
-        args.learning_rate = 1e-6
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -666,41 +539,30 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    if args.train_text_encoder:
-        text_encoder_params_to_optimize = text_encoder.parameters()
-    else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-
-    # Initialize the optimizer
-    optimizer = optimizer_class(
-        [
-            {
-                'params': unet.parameters(),
-            },
-            {
-                'params': text_encoder_params_to_optimize,
-            }
-        ],
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: VlpnDataItem):
-        cond3 = args.collection is None or args.collection in item.collection
-        cond4 = args.exclude_collections is None or not any(
-            collection in item.collection
-            for collection in args.exclude_collections
-        )
-        return cond3 and cond4
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        dtype=weight_dtype,
+        seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    # Initial TI
+
+    print("Phase 1: Textual Inversion")
+
+    cur_dir = output_dir.joinpath("1-ti")
+    cur_dir.mkdir(parents=True, exist_ok=True)
 
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
@@ -709,182 +571,146 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        num_buckets=args.num_buckets,
-        progressive_buckets=args.progressive_buckets,
-        bucket_step_size=args.bucket_step_size,
-        bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=keyword_filter,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
 
-    train_dataloader = datamodule.train_dataloader
-    val_dataloader = datamodule.val_dataloader
-
-    if args.num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            sample_scheduler,
-            datamodule.data_train,
-            args.sample_batch_size,
-            args.sample_image_size,
-            args.sample_steps
-        )
-
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            min_lr=args.lr_min_lr,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps
-        )
+    optimizer = optimizer_class(
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=2e-1,
+        weight_decay=0.0,
+    )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    lr_scheduler = get_scheduler(
+        "linear",
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        train_epochs=30,
+        warmup_epochs=10,
+    )
+
+    trainer(
+        project="textual_inversion",
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=30,
+        sample_frequency=5,
+        checkpoint_frequency=9999999,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        output_dir=cur_dir,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        learning_rate=2e-1,
+        gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=True,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )
 
-    vae.to(accelerator.device, dtype=weight_dtype)
+    # Dreambooth
 
-    if args.use_ema:
-        ema_unet.to(accelerator.device)
+    print("Phase 2: Dreambooth")
 
-    @contextmanager
-    def on_train(epoch: int):
-        try:
-            tokenizer.train()
+    cur_dir = output_dir.joinpath("2db")
+    cur_dir.mkdir(parents=True, exist_ok=True)
 
-            if epoch < args.train_text_encoder_epochs:
-                text_encoder.train()
-            elif epoch == args.train_text_encoder_epochs:
-                text_encoder.requires_grad_(False)
+    args.seed = (args.seed + 28635) >> 32
 
-            yield
-        finally:
-            pass
+    datamodule = VlpnDataModule(
+        data_file=args.train_data_file,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=args.class_image_dir,
+        num_class_images=args.num_class_images,
+        size=args.resolution,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
+        bucket_step_size=args.bucket_step_size,
+        bucket_max_pixels=args.bucket_max_pixels,
+        dropout=args.tag_dropout,
+        shuffle=not args.no_tag_shuffle,
+        template_key=args.train_data_template,
+        valid_set_size=args.valid_set_size,
+        valid_set_repeat=args.valid_set_repeat,
+        seed=args.seed,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        dtype=weight_dtype
+    )
+    datamodule.setup()
 
-    @contextmanager
-    def on_eval():
-        try:
-            tokenizer.eval()
-            text_encoder.eval()
-
-            ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
-
-    def on_before_optimize(epoch: int):
-        if accelerator.sync_gradients:
-            params_to_clip = [unet.parameters()]
-            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
-                params_to_clip.append(text_encoder.parameters())
-            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
-
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        if not args.train_text_encoder:
-            text_encoder.text_model.embeddings.normalize(
-                args.decay_target,
-                min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
-            )
-
-    def on_log():
-        if args.use_ema:
-            return {"ema_decay": ema_unet.decay}
-        return {}
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        ema_unet=ema_unet,
+    optimizer = optimizer_class(
+        [
+            {
+                'params': unet.parameters(),
+            },
+            {
+                'params': text_encoder.parameters(),
+            }
+        ],
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
+
+    trainer(
+        project="dreambooth",
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
         tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        scheduler=sample_scheduler,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        output_dir=output_dir,
-        sample_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
+        sample_scheduler=sample_scheduler,
+        output_dir=cur_dir,
+        gradient_checkpointing=args.gradient_checkpointing,
+        train_text_encoder_epochs=args.train_text_encoder_epochs,
+        max_grad_norm=args.max_grad_norm,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
-    if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=config)
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            model=unet,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            on_train=on_train,
-            on_eval=on_eval,
-            on_before_optimize=on_before_optimize,
-            on_after_optimize=on_after_optimize,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e2)
-
-        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            model=unet,
-            checkpointer=checkpointer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=0,
-            num_epochs=args.num_train_epochs,
-            on_log=on_log,
-            on_train=on_train,
-            on_after_optimize=on_after_optimize,
-            on_eval=on_eval
-        )
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
+    )
 
 
 if __name__ == "__main__":
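
The rewritten script drives both phases through one partially applied train function: everything the phases share (accelerator, models, noise scheduler, dtype, seed, strategy callbacks) is bound once, and each phase supplies only its own dataloaders, optimizer, scheduler, and hyperparameters. A self-contained toy of that pattern (the train stub and the phase-2 values are illustrative; the real function lives in training/functional.py):

    from functools import partial

    def train(*, seed: int, dtype: str, project: str,
              learning_rate: float, num_train_epochs: int) -> None:
        # Stand-in for training.functional.train: just reports its config.
        print(f"[{project}] lr={learning_rate} epochs={num_train_epochs} "
              f"seed={seed} dtype={dtype}")

    # Bind the arguments shared by both phases exactly once.
    trainer = partial(train, seed=0, dtype="float16")

    # Phase 1: short, fixed textual-inversion warm-up for the new tokens.
    trainer(project="textual_inversion", learning_rate=2e-1, num_train_epochs=30)

    # Phase 2: the Dreambooth fine-tune proper, driven by the CLI settings
    # (values here are placeholders, not defaults from this commit).
    trainer(project="dreambooth", learning_rate=2e-6, num_train_epochs=100)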
diff --git a/train_ti.py b/train_ti.py
index 2497519..48a2333 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,7 @@ from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from data.csv import VlpnDataModule, VlpnDataItem
+from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -446,15 +446,15 @@ def parse_args():
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)]
-
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
     if len(args.initializer_tokens) == 0:
         raise ValueError("You must specify --initializer_tokens")
 
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
@@ -544,9 +544,6 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )
 
-    if args.find_lr:
-        args.learning_rate = 1e-5
-
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -563,19 +560,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: VlpnDataItem):
-        cond1 = any(
-            keyword in part
-            for keyword in args.placeholder_tokens
-            for part in item.prompt
-        )
-        cond3 = args.collection is None or args.collection in item.collection
-        cond4 = args.exclude_collections is None or not any(
-            collection in item.collection
-            for collection in args.exclude_collections
-        )
-        return cond1 and cond3 and cond4
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -593,7 +577,7 @@ def main():
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=keyword_filter,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
@@ -622,8 +606,6 @@ def main():
         text_encoder=text_encoder,
         vae=vae,
         noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
         dtype=weight_dtype,
         seed=args.seed,
         callbacks_fn=textual_inversion_strategy
@@ -638,25 +620,25 @@ def main():
         amsgrad=args.adam_amsgrad,
     )
 
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            min_lr=args.lr_min_lr,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-        )
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
 
     trainer(
+        project="textual_inversion",
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
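
Both scripts now normalize the token arguments in the same order: validate --initializer_tokens first, then derive default placeholder names from their count. The old default path also passed the token list itself to range(), which raises a TypeError; the fix passes its length. A standalone sketch of the corrected logic (the function name is ours; the checks mirror parse_args()):

    def normalize_token_args(placeholder_tokens, initializer_tokens):
        if isinstance(placeholder_tokens, str):
            placeholder_tokens = [placeholder_tokens]

        if isinstance(initializer_tokens, str):
            initializer_tokens = [initializer_tokens] * len(placeholder_tokens)

        if len(initializer_tokens) == 0:
            raise ValueError("You must specify --initializer_tokens")

        # Derive defaults only after the initializers are known, and pass
        # range() the count rather than the list.
        if len(placeholder_tokens) == 0:
            placeholder_tokens = [f"<*{i}>" for i in range(len(initializer_tokens))]

        if len(placeholder_tokens) != len(initializer_tokens):
            raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

        return placeholder_tokens, initializer_tokens

    print(normalize_token_args([], ["cat", "dog"]))
    # (['<*0>', '<*1>'], ['cat', 'dog'])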
diff --git a/training/functional.py b/training/functional.py
index 5984ffb..f5c111e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -494,10 +494,11 @@ def train(
     text_encoder: CLIPTextModel,
     vae: AutoencoderKL,
     noise_scheduler: DDPMScheduler,
-    train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
     dtype: torch.dtype,
     seed: int,
+    project: str,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
@@ -544,7 +545,7 @@ def train(
     )
 
     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
+        accelerator.init_trackers(project)
 
     train_loop(
         accelerator=accelerator,
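
With project, train_dataloader, and val_dataloader grouped after seed in the signature, call sites pass the run name alongside the data, and init_trackers records each phase under its own name instead of the hard-coded "textual_inversion". A reduced stub of the new calling convention (the body is a stand-in; the real function also takes the models, schedulers, and callbacks shown above):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def train(
        dtype: torch.dtype,
        seed: int,
        project: str,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
    ) -> None:
        # Stand-in body: the real train() forwards `project` to
        # accelerator.init_trackers(project) on the main process.
        print(f"run={project} dtype={dtype} seed={seed} "
              f"steps/epoch={len(train_dataloader)}")

    loader = DataLoader(TensorDataset(torch.zeros(8, 4)), batch_size=2)
    train(torch.float32, 0, "dreambooth", loader, loader)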