diff options
| -rw-r--r-- | data/csv.py | 15 | ||||
| -rw-r--r-- | train_dreambooth.py | 41 | ||||
| -rw-r--r-- | train_lora.py | 2 | ||||
| -rw-r--r-- | train_ti.py | 43 | ||||
| -rw-r--r-- | training/lr.py | 13 |
5 files changed, 89 insertions, 25 deletions
diff --git a/data/csv.py b/data/csv.py index 0ad36dc..4da5d64 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -41,6 +41,7 @@ class CSVDataItem(NamedTuple): | |||
| 41 | prompt: list[str] | 41 | prompt: list[str] |
| 42 | cprompt: str | 42 | cprompt: str |
| 43 | nprompt: str | 43 | nprompt: str |
| 44 | mode: list[str] | ||
| 44 | 45 | ||
| 45 | 46 | ||
| 46 | class CSVDataModule(): | 47 | class CSVDataModule(): |
| @@ -56,7 +57,6 @@ class CSVDataModule(): | |||
| 56 | dropout: float = 0, | 57 | dropout: float = 0, |
| 57 | interpolation: str = "bicubic", | 58 | interpolation: str = "bicubic", |
| 58 | center_crop: bool = False, | 59 | center_crop: bool = False, |
| 59 | mode: Optional[str] = None, | ||
| 60 | template_key: str = "template", | 60 | template_key: str = "template", |
| 61 | valid_set_size: Optional[int] = None, | 61 | valid_set_size: Optional[int] = None, |
| 62 | generator: Optional[torch.Generator] = None, | 62 | generator: Optional[torch.Generator] = None, |
| @@ -81,7 +81,6 @@ class CSVDataModule(): | |||
| 81 | self.repeats = repeats | 81 | self.repeats = repeats |
| 82 | self.dropout = dropout | 82 | self.dropout = dropout |
| 83 | self.center_crop = center_crop | 83 | self.center_crop = center_crop |
| 84 | self.mode = mode | ||
| 85 | self.template_key = template_key | 84 | self.template_key = template_key |
| 86 | self.interpolation = interpolation | 85 | self.interpolation = interpolation |
| 87 | self.valid_set_size = valid_set_size | 86 | self.valid_set_size = valid_set_size |
| @@ -113,6 +112,7 @@ class CSVDataModule(): | |||
| 113 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 112 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 114 | expansions | 113 | expansions |
| 115 | )), | 114 | )), |
| 115 | item["mode"].split(", ") if "mode" in item else [] | ||
| 116 | ) | 116 | ) |
| 117 | for item in data | 117 | for item in data |
| 118 | ] | 118 | ] |
| @@ -133,6 +133,7 @@ class CSVDataModule(): | |||
| 133 | item.prompt, | 133 | item.prompt, |
| 134 | item.cprompt, | 134 | item.cprompt, |
| 135 | item.nprompt, | 135 | item.nprompt, |
| 136 | item.mode, | ||
| 136 | ) | 137 | ) |
| 137 | for item in items | 138 | for item in items |
| 138 | for i in range(image_multiplier) | 139 | for i in range(image_multiplier) |
| @@ -145,20 +146,12 @@ class CSVDataModule(): | |||
| 145 | expansions = metadata["expansions"] if "expansions" in metadata else {} | 146 | expansions = metadata["expansions"] if "expansions" in metadata else {} |
| 146 | items = metadata["items"] if "items" in metadata else [] | 147 | items = metadata["items"] if "items" in metadata else [] |
| 147 | 148 | ||
| 148 | if self.mode is not None: | ||
| 149 | items = [ | ||
| 150 | item | ||
| 151 | for item in items | ||
| 152 | if "mode" in item and self.mode in item["mode"].split(", ") | ||
| 153 | ] | ||
| 154 | items = self.prepare_items(template, expansions, items) | 149 | items = self.prepare_items(template, expansions, items) |
| 155 | items = self.filter_items(items) | 150 | items = self.filter_items(items) |
| 156 | 151 | ||
| 157 | num_images = len(items) | 152 | num_images = len(items) |
| 158 | 153 | ||
| 159 | valid_set_size = int(num_images * 0.1) | 154 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1) |
| 160 | if self.valid_set_size: | ||
| 161 | valid_set_size = min(valid_set_size, self.valid_set_size) | ||
| 162 | valid_set_size = max(valid_set_size, 1) | 155 | valid_set_size = max(valid_set_size, 1) |
| 163 | train_set_size = num_images - valid_set_size | 156 | train_set_size = num_images - valid_set_size |
| 164 | 157 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 202d52c..072150b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -22,7 +22,7 @@ from slugify import slugify | |||
| 22 | 22 | ||
| 23 | from common import load_text_embeddings, load_config | 23 | from common import load_text_embeddings, load_config |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule, CSVDataItem |
| 26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
| 27 | from training.ti import patch_trainable_embeddings | 27 | from training.ti import patch_trainable_embeddings |
| 28 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args |
| @@ -83,6 +83,18 @@ def parse_args(): | |||
| 83 | help="A token to use as initializer word." | 83 | help="A token to use as initializer word." |
| 84 | ) | 84 | ) |
| 85 | parser.add_argument( | 85 | parser.add_argument( |
| 86 | "--exclude_keywords", | ||
| 87 | type=str, | ||
| 88 | nargs='*', | ||
| 89 | help="Skip dataset items containing a listed keyword.", | ||
| 90 | ) | ||
| 91 | parser.add_argument( | ||
| 92 | "--exclude_modes", | ||
| 93 | type=str, | ||
| 94 | nargs='*', | ||
| 95 | help="Exclude all items with a listed mode.", | ||
| 96 | ) | ||
| 97 | parser.add_argument( | ||
| 86 | "--train_text_encoder", | 98 | "--train_text_encoder", |
| 87 | action="store_true", | 99 | action="store_true", |
| 88 | default=True, | 100 | default=True, |
| @@ -379,6 +391,12 @@ def parse_args(): | |||
| 379 | if len(args.placeholder_token) != len(args.initializer_token): | 391 | if len(args.placeholder_token) != len(args.initializer_token): |
| 380 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") | 392 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") |
| 381 | 393 | ||
| 394 | if isinstance(args.exclude_keywords, str): | ||
| 395 | args.exclude_keywords = [args.exclude_keywords] | ||
| 396 | |||
| 397 | if isinstance(args.exclude_modes, str): | ||
| 398 | args.exclude_modes = [args.exclude_modes] | ||
| 399 | |||
| 382 | if args.output_dir is None: | 400 | if args.output_dir is None: |
| 383 | raise ValueError("You must specify --output_dir") | 401 | raise ValueError("You must specify --output_dir") |
| 384 | 402 | ||
| @@ -636,6 +654,19 @@ def main(): | |||
| 636 | elif args.mixed_precision == "bf16": | 654 | elif args.mixed_precision == "bf16": |
| 637 | weight_dtype = torch.bfloat16 | 655 | weight_dtype = torch.bfloat16 |
| 638 | 656 | ||
| 657 | def keyword_filter(item: CSVDataItem): | ||
| 658 | cond2 = args.exclude_keywords is None or not any( | ||
| 659 | keyword in part | ||
| 660 | for keyword in args.exclude_keywords | ||
| 661 | for part in item.prompt | ||
| 662 | ) | ||
| 663 | cond3 = args.mode is None or args.mode in item.mode | ||
| 664 | cond4 = args.exclude_modes is None or not any( | ||
| 665 | mode in item.mode | ||
| 666 | for mode in args.exclude_modes | ||
| 667 | ) | ||
| 668 | return cond2 and cond3 and cond4 | ||
| 669 | |||
| 639 | def collate_fn(examples): | 670 | def collate_fn(examples): |
| 640 | prompts = [example["prompts"] for example in examples] | 671 | prompts = [example["prompts"] for example in examples] |
| 641 | cprompts = [example["cprompts"] for example in examples] | 672 | cprompts = [example["cprompts"] for example in examples] |
| @@ -671,12 +702,12 @@ def main(): | |||
| 671 | num_class_images=args.num_class_images, | 702 | num_class_images=args.num_class_images, |
| 672 | size=args.resolution, | 703 | size=args.resolution, |
| 673 | repeats=args.repeats, | 704 | repeats=args.repeats, |
| 674 | mode=args.mode, | ||
| 675 | dropout=args.tag_dropout, | 705 | dropout=args.tag_dropout, |
| 676 | center_crop=args.center_crop, | 706 | center_crop=args.center_crop, |
| 677 | template_key=args.train_data_template, | 707 | template_key=args.train_data_template, |
| 678 | valid_set_size=args.valid_set_size, | 708 | valid_set_size=args.valid_set_size, |
| 679 | num_workers=args.dataloader_num_workers, | 709 | num_workers=args.dataloader_num_workers, |
| 710 | filter=keyword_filter, | ||
| 680 | collate_fn=collate_fn | 711 | collate_fn=collate_fn |
| 681 | ) | 712 | ) |
| 682 | 713 | ||
| @@ -782,6 +813,10 @@ def main(): | |||
| 782 | config = vars(args).copy() | 813 | config = vars(args).copy() |
| 783 | config["initializer_token"] = " ".join(config["initializer_token"]) | 814 | config["initializer_token"] = " ".join(config["initializer_token"]) |
| 784 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 815 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
| 816 | if config["exclude_modes"] is not None: | ||
| 817 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | ||
| 818 | if config["exclude_keywords"] is not None: | ||
| 819 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | ||
| 785 | accelerator.init_trackers("dreambooth", config=config) | 820 | accelerator.init_trackers("dreambooth", config=config) |
| 786 | 821 | ||
| 787 | # Train! | 822 | # Train! |
| @@ -879,7 +914,7 @@ def main(): | |||
| 879 | target, target_prior = torch.chunk(target, 2, dim=0) | 914 | target, target_prior = torch.chunk(target, 2, dim=0) |
| 880 | 915 | ||
| 881 | # Compute instance loss | 916 | # Compute instance loss |
| 882 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | 917 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
| 883 | 918 | ||
| 884 | # Compute prior loss | 919 | # Compute prior loss |
| 885 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 920 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") |
diff --git a/train_lora.py b/train_lora.py index 9a42cae..de878a4 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -810,7 +810,7 @@ def main(): | |||
| 810 | target, target_prior = torch.chunk(target, 2, dim=0) | 810 | target, target_prior = torch.chunk(target, 2, dim=0) |
| 811 | 811 | ||
| 812 | # Compute instance loss | 812 | # Compute instance loss |
| 813 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | 813 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
| 814 | 814 | ||
| 815 | # Compute prior loss | 815 | # Compute prior loss |
| 816 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 816 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") |
diff --git a/train_ti.py b/train_ti.py index b1f6a49..6aa4007 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -93,6 +93,18 @@ def parse_args(): | |||
| 93 | help="The directory where class images will be saved.", | 93 | help="The directory where class images will be saved.", |
| 94 | ) | 94 | ) |
| 95 | parser.add_argument( | 95 | parser.add_argument( |
| 96 | "--exclude_keywords", | ||
| 97 | type=str, | ||
| 98 | nargs='*', | ||
| 99 | help="Skip dataset items containing a listed keyword.", | ||
| 100 | ) | ||
| 101 | parser.add_argument( | ||
| 102 | "--exclude_modes", | ||
| 103 | type=str, | ||
| 104 | nargs='*', | ||
| 105 | help="Exclude all items with a listed mode.", | ||
| 106 | ) | ||
| 107 | parser.add_argument( | ||
| 96 | "--repeats", | 108 | "--repeats", |
| 97 | type=int, | 109 | type=int, |
| 98 | default=1, | 110 | default=1, |
| @@ -120,7 +132,8 @@ def parse_args(): | |||
| 120 | "--seed", | 132 | "--seed", |
| 121 | type=int, | 133 | type=int, |
| 122 | default=None, | 134 | default=None, |
| 123 | help="A seed for reproducible training.") | 135 | help="A seed for reproducible training." |
| 136 | ) | ||
| 124 | parser.add_argument( | 137 | parser.add_argument( |
| 125 | "--resolution", | 138 | "--resolution", |
| 126 | type=int, | 139 | type=int, |
| @@ -356,6 +369,12 @@ def parse_args(): | |||
| 356 | if len(args.placeholder_token) != len(args.initializer_token): | 369 | if len(args.placeholder_token) != len(args.initializer_token): |
| 357 | raise ValueError("You must specify --placeholder_token") | 370 | raise ValueError("You must specify --placeholder_token") |
| 358 | 371 | ||
| 372 | if isinstance(args.exclude_keywords, str): | ||
| 373 | args.exclude_keywords = [args.exclude_keywords] | ||
| 374 | |||
| 375 | if isinstance(args.exclude_modes, str): | ||
| 376 | args.exclude_modes = [args.exclude_modes] | ||
| 377 | |||
| 359 | if args.output_dir is None: | 378 | if args.output_dir is None: |
| 360 | raise ValueError("You must specify --output_dir") | 379 | raise ValueError("You must specify --output_dir") |
| 361 | 380 | ||
| @@ -576,11 +595,22 @@ def main(): | |||
| 576 | weight_dtype = torch.bfloat16 | 595 | weight_dtype = torch.bfloat16 |
| 577 | 596 | ||
| 578 | def keyword_filter(item: CSVDataItem): | 597 | def keyword_filter(item: CSVDataItem): |
| 579 | return any( | 598 | cond1 = any( |
| 580 | keyword in part | 599 | keyword in part |
| 581 | for keyword in args.placeholder_token | 600 | for keyword in args.placeholder_token |
| 582 | for part in item.prompt | 601 | for part in item.prompt |
| 583 | ) | 602 | ) |
| 603 | cond2 = args.exclude_keywords is None or not any( | ||
| 604 | keyword in part | ||
| 605 | for keyword in args.exclude_keywords | ||
| 606 | for part in item.prompt | ||
| 607 | ) | ||
| 608 | cond3 = args.mode is None or args.mode in item.mode | ||
| 609 | cond4 = args.exclude_modes is None or not any( | ||
| 610 | mode in item.mode | ||
| 611 | for mode in args.exclude_modes | ||
| 612 | ) | ||
| 613 | return cond1 and cond2 and cond3 and cond4 | ||
| 584 | 614 | ||
| 585 | def collate_fn(examples): | 615 | def collate_fn(examples): |
| 586 | prompts = [example["prompts"] for example in examples] | 616 | prompts = [example["prompts"] for example in examples] |
| @@ -617,7 +647,6 @@ def main(): | |||
| 617 | num_class_images=args.num_class_images, | 647 | num_class_images=args.num_class_images, |
| 618 | size=args.resolution, | 648 | size=args.resolution, |
| 619 | repeats=args.repeats, | 649 | repeats=args.repeats, |
| 620 | mode=args.mode, | ||
| 621 | dropout=args.tag_dropout, | 650 | dropout=args.tag_dropout, |
| 622 | center_crop=args.center_crop, | 651 | center_crop=args.center_crop, |
| 623 | template_key=args.train_data_template, | 652 | template_key=args.train_data_template, |
| @@ -769,7 +798,7 @@ def main(): | |||
| 769 | target, target_prior = torch.chunk(target, 2, dim=0) | 798 | target, target_prior = torch.chunk(target, 2, dim=0) |
| 770 | 799 | ||
| 771 | # Compute instance loss | 800 | # Compute instance loss |
| 772 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() | 801 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") |
| 773 | 802 | ||
| 774 | # Compute prior loss | 803 | # Compute prior loss |
| 775 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | 804 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") |
| @@ -785,7 +814,7 @@ def main(): | |||
| 785 | 814 | ||
| 786 | if args.find_lr: | 815 | if args.find_lr: |
| 787 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | 816 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) |
| 788 | lr_finder.run(min_lr=1e-6, num_train_batches=4) | 817 | lr_finder.run(min_lr=1e-6, num_train_batches=1) |
| 789 | 818 | ||
| 790 | plt.savefig(basepath.joinpath("lr.png")) | 819 | plt.savefig(basepath.joinpath("lr.png")) |
| 791 | plt.close() | 820 | plt.close() |
| @@ -798,6 +827,10 @@ def main(): | |||
| 798 | config = vars(args).copy() | 827 | config = vars(args).copy() |
| 799 | config["initializer_token"] = " ".join(config["initializer_token"]) | 828 | config["initializer_token"] = " ".join(config["initializer_token"]) |
| 800 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 829 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
| 830 | if config["exclude_modes"] is not None: | ||
| 831 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | ||
| 832 | if config["exclude_keywords"] is not None: | ||
| 833 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | ||
| 801 | accelerator.init_trackers("textual_inversion", config=config) | 834 | accelerator.init_trackers("textual_inversion", config=config) |
| 802 | 835 | ||
| 803 | # Train! | 836 | # Train! |
diff --git a/training/lr.py b/training/lr.py index ef01906..0c5ce9e 100644 --- a/training/lr.py +++ b/training/lr.py | |||
| @@ -43,9 +43,6 @@ class LRFinder(): | |||
| 43 | ) | 43 | ) |
| 44 | progress_bar.set_description("Epoch X / Y") | 44 | progress_bar.set_description("Epoch X / Y") |
| 45 | 45 | ||
| 46 | train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches] | ||
| 47 | val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches] | ||
| 48 | |||
| 49 | for epoch in range(num_epochs): | 46 | for epoch in range(num_epochs): |
| 50 | progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 47 | progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| 51 | 48 | ||
| @@ -54,7 +51,10 @@ class LRFinder(): | |||
| 54 | 51 | ||
| 55 | self.model.train() | 52 | self.model.train() |
| 56 | 53 | ||
| 57 | for batch in train_workload: | 54 | for step, batch in enumerate(self.train_dataloader): |
| 55 | if step >= num_train_batches: | ||
| 56 | break | ||
| 57 | |||
| 58 | with self.accelerator.accumulate(self.model): | 58 | with self.accelerator.accumulate(self.model): |
| 59 | loss, acc, bsz = self.loss_fn(batch) | 59 | loss, acc, bsz = self.loss_fn(batch) |
| 60 | 60 | ||
| @@ -69,7 +69,10 @@ class LRFinder(): | |||
| 69 | self.model.eval() | 69 | self.model.eval() |
| 70 | 70 | ||
| 71 | with torch.inference_mode(): | 71 | with torch.inference_mode(): |
| 72 | for batch in val_workload: | 72 | for step, batch in enumerate(self.val_dataloader): |
| 73 | if step >= num_val_batches: | ||
| 74 | break | ||
| 75 | |||
| 73 | loss, acc, bsz = self.loss_fn(batch) | 76 | loss, acc, bsz = self.loss_fn(batch) |
| 74 | avg_loss.update(loss.detach_(), bsz) | 77 | avg_loss.update(loss.detach_(), bsz) |
| 75 | avg_acc.update(acc.detach_(), bsz) | 78 | avg_acc.update(acc.detach_(), bsz) |
