-rw-r--r-- | data/csv.py            |  19
-rw-r--r-- | train_dreambooth.py    | 474
-rw-r--r-- | train_ti.py            |  62
-rw-r--r-- | training/functional.py |   7
4 files changed, 195 insertions, 367 deletions
diff --git a/data/csv.py b/data/csv.py
index 2a8115b..2b1e202 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -130,6 +130,25 @@ class VlpnDataItem(NamedTuple):
     collection: list[str]
 
 
+def keyword_filter(
+    placeholder_tokens: Optional[list[str]],
+    collection: Optional[list[str]],
+    exclude_collections: Optional[list[str]],
+    item: VlpnDataItem
+):
+    cond1 = placeholder_tokens is None or any(
+        keyword in part
+        for keyword in placeholder_tokens
+        for part in item.prompt
+    )
+    cond2 = collection is None or collection in item.collection
+    cond3 = exclude_collections is None or not any(
+        collection in item.collection
+        for collection in exclude_collections
+    )
+    return cond1 and cond2 and cond3
+
+
 class VlpnDataModule():
     def __init__(
         self,
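Usage sketch (not part of this commit): the new module-level keyword_filter is meant to be bound with functools.partial and handed to VlpnDataModule, which then calls it once per item. A minimal, self-contained illustration, with VlpnDataItem reduced to the two fields the filter actually reads and made-up token/collection names:

    from functools import partial
    from typing import NamedTuple


    class Item(NamedTuple):
        # stand-in for VlpnDataItem; only the fields keyword_filter touches
        prompt: list[str]
        collection: list[str]


    def keyword_filter(placeholder_tokens, collection, exclude_collections, item):
        # same logic as the function added to data/csv.py above
        cond1 = placeholder_tokens is None or any(
            keyword in part
            for keyword in placeholder_tokens
            for part in item.prompt
        )
        cond2 = collection is None or collection in item.collection
        cond3 = exclude_collections is None or not any(
            c in item.collection for c in exclude_collections
        )
        return cond1 and cond2 and cond3


    item_filter = partial(keyword_filter, ["<token>"], None, ["excluded"])

    print(item_filter(Item(prompt=["a photo of <token>"], collection=["main"])))      # True
    print(item_filter(Item(prompt=["a photo of <token>"], collection=["excluded"])))  # False

This is the same shape as the partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections) calls in the training scripts below.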
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 71bad7e..944256c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1,10 +1,8 @@
 import argparse
-import itertools
 import datetime
 import logging
 from pathlib import Path
 from functools import partial
-from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -12,18 +10,15 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, UNet2DConditionModel
-import matplotlib.pyplot as plt
-from transformers import CLIPTextModel
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import VlpnDataModule, VlpnDataItem
+from data.csv import VlpnDataModule, keyword_filter
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.strategy.ti import textual_inversion_strategy
+from training.strategy.dreambooth import dreambooth_strategy
 from training.optimization import get_scheduler
-from training.lr import LRFinder
-from training.util import CheckpointerBase, EMAModel, save_args, generate_class_images, add_placeholder_tokens, get_models
-from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import save_args
 
 logger = get_logger(__name__)
 
@@ -73,7 +68,7 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
-        "--initializer_token",
+        "--initializer_tokens",
         type=str,
         nargs='*',
         default=[],
@@ -151,7 +146,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=1,
+        default=0,
         help="How many class images to generate."
     )
     parser.add_argument(
@@ -437,23 +432,23 @@ def parse_args():
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_token)]
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token] * len(args.placeholder_tokens)
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")
 
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
-    if len(args.placeholder_tokens) != len(args.initializer_token):
-        raise ValueError("--placeholder_tokens and --initializer_token must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
     if args.num_vectors is None:
         args.num_vectors = 1
 
     if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
 
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
@@ -470,102 +465,9 @@ def parse_args():
     return args
 
 
-class Checkpointer(CheckpointerBase):
-    def __init__(
-        self,
-        weight_dtype: torch.dtype,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        ema_unet: EMAModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
-        scheduler,
-        *args,
-        **kwargs
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.ema_unet = ema_unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.scheduler = scheduler
-
-    @torch.no_grad()
-    def save_model(self):
-        print("Saving model...")
-
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
-
-        with ema_context:
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            )
-            pipeline.save_pretrained(self.output_dir.joinpath("model"))
-
-        del unet
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-    @torch.no_grad()
-    def save_samples(self, step):
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
-
-        with ema_context:
-            orig_unet_dtype = unet.dtype
-            orig_text_encoder_dtype = text_encoder.dtype
-
-            unet.to(dtype=self.weight_dtype)
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step)
-
-            unet.to(dtype=orig_unet_dtype)
-            text_encoder.to(dtype=orig_text_encoder_dtype)
-
-        del unet
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-
 def main():
     args = parse_args()
 
-    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
-
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -621,41 +523,12 @@ def main():
     placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
     print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
-    if args.use_ema:
-        ema_unet = EMAModel(
-            unet.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-    else:
-        ema_unet = None
-
-    vae.requires_grad_(False)
-
-    if args.train_text_encoder:
-        print(f"Training entire text encoder.")
-
-        embeddings.persist()
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
-    else:
-        print(f"Training added text embeddings")
-
-        text_encoder.text_model.encoder.requires_grad_(False)
-        text_encoder.text_model.final_layer_norm.requires_grad_(False)
-        text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-        text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
-    if args.find_lr:
-        args.learning_rate = 1e-6
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -666,41 +539,30 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    if args.train_text_encoder:
-        text_encoder_params_to_optimize = text_encoder.parameters()
-    else:
-        text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-
-    # Initialize the optimizer
-    optimizer = optimizer_class(
-        [
-            {
-                'params': unet.parameters(),
-            },
-            {
-                'params': text_encoder_params_to_optimize,
-            }
-        ],
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: VlpnDataItem):
-        cond3 = args.collection is None or args.collection in item.collection
-        cond4 = args.exclude_collections is None or not any(
-            collection in item.collection
-            for collection in args.exclude_collections
-        )
-        return cond3 and cond4
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        dtype=weight_dtype,
+        seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    # Initial TI
+
+    print("Phase 1: Textual Inversion")
+
+    cur_dir = output_dir.joinpath("1-ti")
+    cur_dir.mkdir(parents=True, exist_ok=True)
 
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
@@ -709,183 +571,147 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        num_buckets=args.num_buckets,
-        progressive_buckets=args.progressive_buckets,
-        bucket_step_size=args.bucket_step_size,
-        bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=keyword_filter,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
 
-    train_dataloader = datamodule.train_dataloader
-    val_dataloader = datamodule.val_dataloader
-
-    if args.num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            sample_scheduler,
-            datamodule.data_train,
-            args.sample_batch_size,
-            args.sample_image_size,
-            args.sample_steps
-        )
-
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            min_lr=args.lr_min_lr,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps
-        )
-
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    optimizer = optimizer_class(
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=2e-1,
+        weight_decay=0.0,
     )
 
-    vae.to(accelerator.device, dtype=weight_dtype)
-
-    if args.use_ema:
-        ema_unet.to(accelerator.device)
-
-    @contextmanager
-    def on_train(epoch: int):
-        try:
-            tokenizer.train()
-
-            if epoch < args.train_text_encoder_epochs:
-                text_encoder.train()
-            elif epoch == args.train_text_encoder_epochs:
-                text_encoder.requires_grad_(False)
+    lr_scheduler = get_scheduler(
+        "linear",
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        train_epochs=30,
+        warmup_epochs=10,
+    )
 
-            yield
-        finally:
-            pass
+    trainer(
+        project="textual_inversion",
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=30,
+        sample_frequency=5,
+        checkpoint_frequency=9999999,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        output_dir=cur_dir,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        learning_rate=2e-1,
+        gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=True,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
+    )
 
-    @contextmanager
-    def on_eval():
-        try:
-            tokenizer.eval()
-            text_encoder.eval()
+    # Dreambooth
 
-            ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()
+    print("Phase 2: Dreambooth")
 
-            with ema_context:
-                yield
-        finally:
-            pass
+    cur_dir = output_dir.joinpath("2db")
+    cur_dir.mkdir(parents=True, exist_ok=True)
 
-    def on_before_optimize(epoch: int):
-        if accelerator.sync_gradients:
-            params_to_clip = [unet.parameters()]
-            if args.train_text_encoder and epoch < args.train_text_encoder_epochs:
-                params_to_clip.append(text_encoder.parameters())
-            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm)
+    args.seed = (args.seed + 28635) >> 32
 
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        if not args.train_text_encoder:
-            text_encoder.text_model.embeddings.normalize(
-                args.decay_target,
-                min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start))))
-            )
+    datamodule = VlpnDataModule(
+        data_file=args.train_data_file,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=args.class_image_dir,
+        num_class_images=args.num_class_images,
+        size=args.resolution,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
+        bucket_step_size=args.bucket_step_size,
+        bucket_max_pixels=args.bucket_max_pixels,
+        dropout=args.tag_dropout,
+        shuffle=not args.no_tag_shuffle,
+        template_key=args.train_data_template,
+        valid_set_size=args.valid_set_size,
+        valid_set_repeat=args.valid_set_repeat,
+        seed=args.seed,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        dtype=weight_dtype
+    )
+    datamodule.setup()
 
-    def on_log():
-        if args.use_ema:
-            return {"ema_decay": ema_unet.decay}
-        return {}
+    optimizer = optimizer_class(
+        [
+            {
+                'params': unet.parameters(),
+            },
+            {
+                'params': text_encoder.parameters(),
+            }
+        ],
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
 
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.prior_loss_weight,
-        args.seed,
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
     )
 
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        ema_unet=ema_unet,
+    trainer(
+        project="dreambooth",
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
         tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        scheduler=sample_scheduler,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        output_dir=output_dir,
-        sample_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
+        sample_scheduler=sample_scheduler,
+        output_dir=cur_dir,
+        gradient_checkpointing=args.gradient_checkpointing,
+        train_text_encoder_epochs=args.train_text_encoder_epochs,
+        max_grad_norm=args.max_grad_norm,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )
 
-    if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=config)
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            model=unet,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            on_train=on_train,
-            on_eval=on_eval,
-            on_before_optimize=on_before_optimize,
-            on_after_optimize=on_after_optimize,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e2)
-
-        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            model=unet,
-            checkpointer=checkpointer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=0,
-            num_epochs=args.num_train_epochs,
-            on_log=on_log,
-            on_train=on_train,
-            on_after_optimize=on_after_optimize,
-            on_eval=on_eval
-        )
-
 
 if __name__ == "__main__":
     main()
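Usage sketch (not part of this commit): train_dreambooth.py now drives both phases through one partially applied train() call; the model, accelerator and strategy arguments are bound once, and each phase supplies its own project name, dataloaders, optimizer and scheduler. A minimal illustration of the pattern with a stand-in train function (the real signature lives in training/functional.py):

    from functools import partial


    def train(*, accelerator, unet, project, train_dataloader, num_train_epochs, **kwargs):
        # stand-in for training.functional.train; just reports what it was given
        print(f"[{project}] {num_train_epochs} epochs over {len(train_dataloader)} batches")


    # Shared state is bound once...
    trainer = partial(train, accelerator=object(), unet=object())

    # ...and each phase passes only what differs.
    trainer(project="textual_inversion", train_dataloader=[0] * 10, num_train_epochs=30)
    trainer(project="dreambooth", train_dataloader=[0] * 10, num_train_epochs=100)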
diff --git a/train_ti.py b/train_ti.py
index 2497519..48a2333 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,7 @@ from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from data.csv import VlpnDataModule, VlpnDataItem
+from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -446,15 +446,15 @@ def parse_args():
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)]
-
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
     if len(args.initializer_tokens) == 0:
         raise ValueError("You must specify --initializer_tokens")
 
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
@@ -544,9 +544,6 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )
 
-    if args.find_lr:
-        args.learning_rate = 1e-5
-
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -563,19 +560,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: VlpnDataItem):
-        cond1 = any(
-            keyword in part
-            for keyword in args.placeholder_tokens
-            for part in item.prompt
-        )
-        cond3 = args.collection is None or args.collection in item.collection
-        cond4 = args.exclude_collections is None or not any(
-            collection in item.collection
-            for collection in args.exclude_collections
-        )
-        return cond1 and cond3 and cond4
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -593,7 +577,7 @@ def main():
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=keyword_filter,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
@@ -622,8 +606,6 @@ def main():
         text_encoder=text_encoder,
         vae=vae,
         noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
         dtype=weight_dtype,
         seed=args.seed,
         callbacks_fn=textual_inversion_strategy
@@ -638,25 +620,25 @@ def main():
         amsgrad=args.adam_amsgrad,
     )
 
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            min_lr=args.lr_min_lr,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-        )
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
 
     trainer(
+        project="textual_inversion",
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
diff --git a/training/functional.py b/training/functional.py
index 5984ffb..f5c111e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -494,10 +494,11 @@ def train(
     text_encoder: CLIPTextModel,
     vae: AutoencoderKL,
     noise_scheduler: DDPMScheduler,
-    train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
     dtype: torch.dtype,
     seed: int,
+    project: str,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
@@ -544,7 +545,7 @@ def train(
     )
 
     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
+        accelerator.init_trackers(project)
 
     train_loop(
         accelerator=accelerator,
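Usage sketch (not part of this commit): the new project parameter is what lets the two phases above log under separate tracker runs; train() simply forwards it to accelerator.init_trackers. A minimal illustration with a stub accelerator, assuming nothing beyond what the hunk above shows:

    class StubAccelerator:
        # stand-in for accelerate.Accelerator, just enough to show the run naming
        is_main_process = True

        def init_trackers(self, project_name):
            print(f"tracking run: {project_name}")


    def train(accelerator, project, **kwargs):
        # reduced shape of training.functional.train after this change
        if accelerator.is_main_process:
            accelerator.init_trackers(project)


    accelerator = StubAccelerator()
    train(accelerator, project="textual_inversion")  # phase 1
    train(accelerator, project="dreambooth")         # phase 2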
