commit dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0
tree   da07cbadfad6f54e55e43e2fda21cef80cded5ea
parent Update
author    Volpeon <git@volpeon.ink>  2022-12-30 13:48:26 +0100
committer Volpeon <git@volpeon.ink>  2022-12-30 13:48:26 +0100

Training script improvements
 data/csv.py         | 15
 train_dreambooth.py | 41
 train_lora.py       |  2
 train_ti.py         | 43
 training/lr.py      | 13

 5 files changed, 89 insertions(+), 25 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 0ad36dc..4da5d64 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,6 +41,7 @@ class CSVDataItem(NamedTuple):
     prompt: list[str]
     cprompt: str
     nprompt: str
+    mode: list[str]
 
 
 class CSVDataModule():
@@ -56,7 +57,6 @@ class CSVDataModule():
         dropout: float = 0,
         interpolation: str = "bicubic",
         center_crop: bool = False,
-        mode: Optional[str] = None,
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
@@ -81,7 +81,6 @@ class CSVDataModule():
         self.repeats = repeats
         self.dropout = dropout
         self.center_crop = center_crop
-        self.mode = mode
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -113,6 +112,7 @@ class CSVDataModule():
                 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
                 expansions
             )),
+            item["mode"].split(", ") if "mode" in item else []
         )
         for item in data
     ]
@@ -133,6 +133,7 @@ class CSVDataModule():
                 item.prompt,
                 item.cprompt,
                 item.nprompt,
+                item.mode,
             )
             for item in items
             for i in range(image_multiplier)
@@ -145,20 +146,12 @@ class CSVDataModule():
         expansions = metadata["expansions"] if "expansions" in metadata else {}
         items = metadata["items"] if "items" in metadata else []
 
-        if self.mode is not None:
-            items = [
-                item
-                for item in items
-                if "mode" in item and self.mode in item["mode"].split(", ")
-            ]
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
 
         num_images = len(items)
 
-        valid_set_size = int(num_images * 0.1)
-        if self.valid_set_size:
-            valid_set_size = min(valid_set_size, self.valid_set_size)
+        valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1)
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
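The data module change has two parts: each CSVDataItem now carries its "mode" tags as a parsed list so callers can filter on them, and valid_set_size becomes an override for the 10% default rather than a cap on it. A minimal sketch of the new split rule, with split_sizes as a hypothetical standalone helper (not part of the module):

from typing import Optional

# Mirrors the new split logic: an explicit valid_set_size now replaces the
# 10% default instead of merely capping it, and a floor of 1 keeps the
# validation set non-empty.
def split_sizes(num_images: int, valid_set_size: Optional[int]) -> tuple[int, int]:
    valid = valid_set_size if valid_set_size is not None else int(num_images * 0.1)
    valid = max(valid, 1)
    return num_images - valid, valid

assert split_sizes(100, None) == (90, 10)  # default: hold out 10%
assert split_sizes(100, 25) == (75, 25)    # explicit size wins, even above 10%
assert split_sizes(5, None) == (4, 1)      # floor of 1 for tiny datasets

Under the old code, valid_set_size=25 on 100 images would still have yielded a validation set of 10; now it yields 25.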
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 202d52c..072150b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -22,7 +22,7 @@ from slugify import slugify
 
 from common import load_text_embeddings, load_config
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import CSVDataModule
+from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -83,6 +83,18 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--exclude_keywords",
+        type=str,
+        nargs='*',
+        help="Skip dataset items containing a listed keyword.",
+    )
+    parser.add_argument(
+        "--exclude_modes",
+        type=str,
+        nargs='*',
+        help="Exclude all items with a listed mode.",
+    )
+    parser.add_argument(
         "--train_text_encoder",
         action="store_true",
         default=True,
@@ -379,6 +391,12 @@ def parse_args():
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
 
+    if isinstance(args.exclude_keywords, str):
+        args.exclude_keywords = [args.exclude_keywords]
+
+    if isinstance(args.exclude_modes, str):
+        args.exclude_modes = [args.exclude_modes]
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -636,6 +654,19 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    def keyword_filter(item: CSVDataItem):
+        cond2 = args.exclude_keywords is None or not any(
+            keyword in part
+            for keyword in args.exclude_keywords
+            for part in item.prompt
+        )
+        cond3 = args.mode is None or args.mode in item.mode
+        cond4 = args.exclude_modes is None or not any(
+            mode in item.mode
+            for mode in args.exclude_modes
+        )
+        return cond2 and cond3 and cond4
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         cprompts = [example["cprompts"] for example in examples]
@@ -671,12 +702,12 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
-        mode=args.mode,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
+        filter=keyword_filter,
         collate_fn=collate_fn
     )
 
@@ -782,6 +813,10 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
+    if config["exclude_modes"] is not None:
+        config["exclude_modes"] = " ".join(config["exclude_modes"])
+    if config["exclude_keywords"] is not None:
+        config["exclude_keywords"] = " ".join(config["exclude_keywords"])
     accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
@@ -879,7 +914,7 @@ def main():
                     target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     # Compute prior loss
                     prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
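The instance-loss change here (and the identical ones in train_lora.py and train_ti.py below) is a simplification rather than a behavioral change: averaging per-image means over a uniformly shaped batch gives the same number as one global mean, since every image contributes equally many elements. A quick check of that equivalence, with arbitrary example shapes:

import torch
import torch.nn.functional as F

# For a uniformly shaped batch, the old per-image-mean chain and the new
# global mean agree (up to floating-point noise).
pred, target = torch.randn(4, 4, 64, 64), torch.randn(4, 4, 64, 64)

old = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3]).mean()
new = F.mse_loss(pred, target, reduction="mean")
assert torch.allclose(old, new)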
diff --git a/train_lora.py b/train_lora.py
index 9a42cae..de878a4 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -810,7 +810,7 @@ def main():
                     target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     # Compute prior loss
                     prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
diff --git a/train_ti.py b/train_ti.py
index b1f6a49..6aa4007 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -93,6 +93,18 @@ def parse_args():
         help="The directory where class images will be saved.",
     )
     parser.add_argument(
+        "--exclude_keywords",
+        type=str,
+        nargs='*',
+        help="Skip dataset items containing a listed keyword.",
+    )
+    parser.add_argument(
+        "--exclude_modes",
+        type=str,
+        nargs='*',
+        help="Exclude all items with a listed mode.",
+    )
+    parser.add_argument(
         "--repeats",
         type=int,
         default=1,
@@ -120,7 +132,8 @@ def parse_args():
         "--seed",
         type=int,
         default=None,
-        help="A seed for reproducible training.")
+        help="A seed for reproducible training."
+    )
     parser.add_argument(
         "--resolution",
         type=int,
@@ -356,6 +369,12 @@ def parse_args():
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("You must specify --placeholder_token")
 
+    if isinstance(args.exclude_keywords, str):
+        args.exclude_keywords = [args.exclude_keywords]
+
+    if isinstance(args.exclude_modes, str):
+        args.exclude_modes = [args.exclude_modes]
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -576,11 +595,22 @@ def main():
         weight_dtype = torch.bfloat16
 
     def keyword_filter(item: CSVDataItem):
-        return any(
+        cond1 = any(
             keyword in part
             for keyword in args.placeholder_token
             for part in item.prompt
         )
+        cond2 = args.exclude_keywords is None or not any(
+            keyword in part
+            for keyword in args.exclude_keywords
+            for part in item.prompt
+        )
+        cond3 = args.mode is None or args.mode in item.mode
+        cond4 = args.exclude_modes is None or not any(
+            mode in item.mode
+            for mode in args.exclude_modes
+        )
+        return cond1 and cond2 and cond3 and cond4
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
@@ -617,7 +647,6 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
-        mode=args.mode,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
         template_key=args.train_data_template,
@@ -769,7 +798,7 @@ def main():
                     target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     # Compute prior loss
                     prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
@@ -785,7 +814,7 @@ def main():
 
     if args.find_lr:
         lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-6, num_train_batches=4)
+        lr_finder.run(min_lr=1e-6, num_train_batches=1)
 
         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
@@ -798,6 +827,10 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
+    if config["exclude_modes"] is not None:
+        config["exclude_modes"] = " ".join(config["exclude_modes"])
+    if config["exclude_keywords"] is not None:
+        config["exclude_keywords"] = " ".join(config["exclude_keywords"])
     accelerator.init_trackers("textual_inversion", config=config)
 
     # Train!
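train_ti.py keeps its original inclusion test (a placeholder token must occur somewhere in the prompt) and adds the same exclusion tests that train_dreambooth.py gains. Note the asymmetry in matching: keyword tests are substring checks against each prompt part, while mode tests are exact membership in the item's mode list. A standalone sketch of the combined semantics, with Item and passes as hypothetical stand-ins for CSVDataItem and keyword_filter:

from typing import NamedTuple, Optional

class Item(NamedTuple):
    prompt: list[str]
    mode: list[str]

def passes(item: Item,
           placeholder_token: list[str],
           exclude_keywords: Optional[list[str]],
           mode: Optional[str],
           exclude_modes: Optional[list[str]]) -> bool:
    # cond1: at least one placeholder token appears in the prompt (substring).
    cond1 = any(k in p for k in placeholder_token for p in item.prompt)
    # cond2: no excluded keyword appears in any prompt part (substring).
    cond2 = exclude_keywords is None or not any(
        k in p for k in exclude_keywords for p in item.prompt)
    # cond3/cond4: mode checks are exact membership in the item's mode list.
    cond3 = mode is None or mode in item.mode
    cond4 = exclude_modes is None or not any(m in item.mode for m in exclude_modes)
    return cond1 and cond2 and cond3 and cond4

item = Item(prompt=["a photo of <token>"], mode=["portrait"])
assert passes(item, ["<token>"], None, None, None)
assert not passes(item, ["<token>"], ["photo"], None, None)    # substring hit
assert not passes(item, ["<token>"], None, "landscape", None)  # mode mismatch

Because keyword exclusion is a substring match, an exclusion like "cat" also skips prompts containing "catapult"; mode exclusion has no such overlap risk.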
diff --git a/training/lr.py b/training/lr.py
index ef01906..0c5ce9e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -43,9 +43,6 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
-        train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches]
-        val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches]
-
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -54,7 +51,10 @@
 
             self.model.train()
 
-            for batch in train_workload:
+            for step, batch in enumerate(self.train_dataloader):
+                if step >= num_train_batches:
+                    break
+
                 with self.accelerator.accumulate(self.model):
                     loss, acc, bsz = self.loss_fn(batch)
 
@@ -69,7 +69,10 @@
             self.model.eval()
 
             with torch.inference_mode():
-                for batch in val_workload:
+                for step, batch in enumerate(self.val_dataloader):
+                    if step >= num_val_batches:
+                        break
+
                     loss, acc, bsz = self.loss_fn(batch)
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
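The LRFinder change swaps pre-materialized batch lists for plain enumerate-and-break loops. The old comprehensions had two costs: the "if i < n" guard filters but does not stop iteration, so the whole dataloader was consumed up front, and the kept batches stayed pinned in memory for the entire run. Streaming also re-draws batches from the dataloader each epoch instead of replaying the same cached ones. A toy illustration of the pattern (toy_loader is a stand-in, not part of the repo):

# Toy stand-in for a dataloader.
def toy_loader():
    for i in range(1000):
        yield f"batch-{i}"

num_train_batches = 2

# Before: consumes all 1000 batches just to keep 2, and holds them for the run.
workload = [batch for i, batch in enumerate(toy_loader()) if i < num_train_batches]

# After: touches only the first 2 batches and retains nothing between steps.
for step, batch in enumerate(toy_loader()):
    if step >= num_train_batches:
        break
    print(batch)  # train on this batch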