summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-30 13:48:26 +0100
committerVolpeon <git@volpeon.ink>2022-12-30 13:48:26 +0100
commitdfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch)
treeda07cbadfad6f54e55e43e2fda21cef80cded5ea
parentUpdate (diff)
downloadtextual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.gz
textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.bz2
textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.zip
Training script improvements
-rw-r--r--data/csv.py15
-rw-r--r--train_dreambooth.py41
-rw-r--r--train_lora.py2
-rw-r--r--train_ti.py43
-rw-r--r--training/lr.py13
5 files changed, 89 insertions, 25 deletions
diff --git a/data/csv.py b/data/csv.py
index 0ad36dc..4da5d64 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,6 +41,7 @@ class CSVDataItem(NamedTuple):
41 prompt: list[str] 41 prompt: list[str]
42 cprompt: str 42 cprompt: str
43 nprompt: str 43 nprompt: str
44 mode: list[str]
44 45
45 46
46class CSVDataModule(): 47class CSVDataModule():
@@ -56,7 +57,6 @@ class CSVDataModule():
56 dropout: float = 0, 57 dropout: float = 0,
57 interpolation: str = "bicubic", 58 interpolation: str = "bicubic",
58 center_crop: bool = False, 59 center_crop: bool = False,
59 mode: Optional[str] = None,
60 template_key: str = "template", 60 template_key: str = "template",
61 valid_set_size: Optional[int] = None, 61 valid_set_size: Optional[int] = None,
62 generator: Optional[torch.Generator] = None, 62 generator: Optional[torch.Generator] = None,
@@ -81,7 +81,6 @@ class CSVDataModule():
81 self.repeats = repeats 81 self.repeats = repeats
82 self.dropout = dropout 82 self.dropout = dropout
83 self.center_crop = center_crop 83 self.center_crop = center_crop
84 self.mode = mode
85 self.template_key = template_key 84 self.template_key = template_key
86 self.interpolation = interpolation 85 self.interpolation = interpolation
87 self.valid_set_size = valid_set_size 86 self.valid_set_size = valid_set_size
@@ -113,6 +112,7 @@ class CSVDataModule():
113 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 112 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
114 expansions 113 expansions
115 )), 114 )),
115 item["mode"].split(", ") if "mode" in item else []
116 ) 116 )
117 for item in data 117 for item in data
118 ] 118 ]
@@ -133,6 +133,7 @@ class CSVDataModule():
133 item.prompt, 133 item.prompt,
134 item.cprompt, 134 item.cprompt,
135 item.nprompt, 135 item.nprompt,
136 item.mode,
136 ) 137 )
137 for item in items 138 for item in items
138 for i in range(image_multiplier) 139 for i in range(image_multiplier)
@@ -145,20 +146,12 @@ class CSVDataModule():
145 expansions = metadata["expansions"] if "expansions" in metadata else {} 146 expansions = metadata["expansions"] if "expansions" in metadata else {}
146 items = metadata["items"] if "items" in metadata else [] 147 items = metadata["items"] if "items" in metadata else []
147 148
148 if self.mode is not None:
149 items = [
150 item
151 for item in items
152 if "mode" in item and self.mode in item["mode"].split(", ")
153 ]
154 items = self.prepare_items(template, expansions, items) 149 items = self.prepare_items(template, expansions, items)
155 items = self.filter_items(items) 150 items = self.filter_items(items)
156 151
157 num_images = len(items) 152 num_images = len(items)
158 153
159 valid_set_size = int(num_images * 0.1) 154 valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1)
160 if self.valid_set_size:
161 valid_set_size = min(valid_set_size, self.valid_set_size)
162 valid_set_size = max(valid_set_size, 1) 155 valid_set_size = max(valid_set_size, 1)
163 train_set_size = num_images - valid_set_size 156 train_set_size = num_images - valid_set_size
164 157
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 202d52c..072150b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -22,7 +22,7 @@ from slugify import slugify
22 22
23from common import load_text_embeddings, load_config 23from common import load_text_embeddings, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule 25from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.ti import patch_trainable_embeddings 27from training.ti import patch_trainable_embeddings
28from training.util import AverageMeter, CheckpointerBase, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args
@@ -83,6 +83,18 @@ def parse_args():
83 help="A token to use as initializer word." 83 help="A token to use as initializer word."
84 ) 84 )
85 parser.add_argument( 85 parser.add_argument(
86 "--exclude_keywords",
87 type=str,
88 nargs='*',
89 help="Skip dataset items containing a listed keyword.",
90 )
91 parser.add_argument(
92 "--exclude_modes",
93 type=str,
94 nargs='*',
95 help="Exclude all items with a listed mode.",
96 )
97 parser.add_argument(
86 "--train_text_encoder", 98 "--train_text_encoder",
87 action="store_true", 99 action="store_true",
88 default=True, 100 default=True,
@@ -379,6 +391,12 @@ def parse_args():
379 if len(args.placeholder_token) != len(args.initializer_token): 391 if len(args.placeholder_token) != len(args.initializer_token):
380 raise ValueError("Number of items in --placeholder_token and --initializer_token must match") 392 raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
381 393
394 if isinstance(args.exclude_keywords, str):
395 args.exclude_keywords = [args.exclude_keywords]
396
397 if isinstance(args.exclude_modes, str):
398 args.exclude_modes = [args.exclude_modes]
399
382 if args.output_dir is None: 400 if args.output_dir is None:
383 raise ValueError("You must specify --output_dir") 401 raise ValueError("You must specify --output_dir")
384 402
@@ -636,6 +654,19 @@ def main():
636 elif args.mixed_precision == "bf16": 654 elif args.mixed_precision == "bf16":
637 weight_dtype = torch.bfloat16 655 weight_dtype = torch.bfloat16
638 656
657 def keyword_filter(item: CSVDataItem):
658 cond2 = args.exclude_keywords is None or not any(
659 keyword in part
660 for keyword in args.exclude_keywords
661 for part in item.prompt
662 )
663 cond3 = args.mode is None or args.mode in item.mode
664 cond4 = args.exclude_modes is None or not any(
665 mode in item.mode
666 for mode in args.exclude_modes
667 )
668 return cond2 and cond3 and cond4
669
639 def collate_fn(examples): 670 def collate_fn(examples):
640 prompts = [example["prompts"] for example in examples] 671 prompts = [example["prompts"] for example in examples]
641 cprompts = [example["cprompts"] for example in examples] 672 cprompts = [example["cprompts"] for example in examples]
@@ -671,12 +702,12 @@ def main():
671 num_class_images=args.num_class_images, 702 num_class_images=args.num_class_images,
672 size=args.resolution, 703 size=args.resolution,
673 repeats=args.repeats, 704 repeats=args.repeats,
674 mode=args.mode,
675 dropout=args.tag_dropout, 705 dropout=args.tag_dropout,
676 center_crop=args.center_crop, 706 center_crop=args.center_crop,
677 template_key=args.train_data_template, 707 template_key=args.train_data_template,
678 valid_set_size=args.valid_set_size, 708 valid_set_size=args.valid_set_size,
679 num_workers=args.dataloader_num_workers, 709 num_workers=args.dataloader_num_workers,
710 filter=keyword_filter,
680 collate_fn=collate_fn 711 collate_fn=collate_fn
681 ) 712 )
682 713
@@ -782,6 +813,10 @@ def main():
782 config = vars(args).copy() 813 config = vars(args).copy()
783 config["initializer_token"] = " ".join(config["initializer_token"]) 814 config["initializer_token"] = " ".join(config["initializer_token"])
784 config["placeholder_token"] = " ".join(config["placeholder_token"]) 815 config["placeholder_token"] = " ".join(config["placeholder_token"])
816 if config["exclude_modes"] is not None:
817 config["exclude_modes"] = " ".join(config["exclude_modes"])
818 if config["exclude_keywords"] is not None:
819 config["exclude_keywords"] = " ".join(config["exclude_keywords"])
785 accelerator.init_trackers("dreambooth", config=config) 820 accelerator.init_trackers("dreambooth", config=config)
786 821
787 # Train! 822 # Train!
@@ -879,7 +914,7 @@ def main():
879 target, target_prior = torch.chunk(target, 2, dim=0) 914 target, target_prior = torch.chunk(target, 2, dim=0)
880 915
881 # Compute instance loss 916 # Compute instance loss
882 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() 917 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
883 918
884 # Compute prior loss 919 # Compute prior loss
885 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 920 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
diff --git a/train_lora.py b/train_lora.py
index 9a42cae..de878a4 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -810,7 +810,7 @@ def main():
810 target, target_prior = torch.chunk(target, 2, dim=0) 810 target, target_prior = torch.chunk(target, 2, dim=0)
811 811
812 # Compute instance loss 812 # Compute instance loss
813 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() 813 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
814 814
815 # Compute prior loss 815 # Compute prior loss
816 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 816 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
diff --git a/train_ti.py b/train_ti.py
index b1f6a49..6aa4007 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -93,6 +93,18 @@ def parse_args():
93 help="The directory where class images will be saved.", 93 help="The directory where class images will be saved.",
94 ) 94 )
95 parser.add_argument( 95 parser.add_argument(
96 "--exclude_keywords",
97 type=str,
98 nargs='*',
99 help="Skip dataset items containing a listed keyword.",
100 )
101 parser.add_argument(
102 "--exclude_modes",
103 type=str,
104 nargs='*',
105 help="Exclude all items with a listed mode.",
106 )
107 parser.add_argument(
96 "--repeats", 108 "--repeats",
97 type=int, 109 type=int,
98 default=1, 110 default=1,
@@ -120,7 +132,8 @@ def parse_args():
120 "--seed", 132 "--seed",
121 type=int, 133 type=int,
122 default=None, 134 default=None,
123 help="A seed for reproducible training.") 135 help="A seed for reproducible training."
136 )
124 parser.add_argument( 137 parser.add_argument(
125 "--resolution", 138 "--resolution",
126 type=int, 139 type=int,
@@ -356,6 +369,12 @@ def parse_args():
356 if len(args.placeholder_token) != len(args.initializer_token): 369 if len(args.placeholder_token) != len(args.initializer_token):
357 raise ValueError("You must specify --placeholder_token") 370 raise ValueError("You must specify --placeholder_token")
358 371
372 if isinstance(args.exclude_keywords, str):
373 args.exclude_keywords = [args.exclude_keywords]
374
375 if isinstance(args.exclude_modes, str):
376 args.exclude_modes = [args.exclude_modes]
377
359 if args.output_dir is None: 378 if args.output_dir is None:
360 raise ValueError("You must specify --output_dir") 379 raise ValueError("You must specify --output_dir")
361 380
@@ -576,11 +595,22 @@ def main():
576 weight_dtype = torch.bfloat16 595 weight_dtype = torch.bfloat16
577 596
578 def keyword_filter(item: CSVDataItem): 597 def keyword_filter(item: CSVDataItem):
579 return any( 598 cond1 = any(
580 keyword in part 599 keyword in part
581 for keyword in args.placeholder_token 600 for keyword in args.placeholder_token
582 for part in item.prompt 601 for part in item.prompt
583 ) 602 )
603 cond2 = args.exclude_keywords is None or not any(
604 keyword in part
605 for keyword in args.exclude_keywords
606 for part in item.prompt
607 )
608 cond3 = args.mode is None or args.mode in item.mode
609 cond4 = args.exclude_modes is None or not any(
610 mode in item.mode
611 for mode in args.exclude_modes
612 )
613 return cond1 and cond2 and cond3 and cond4
584 614
585 def collate_fn(examples): 615 def collate_fn(examples):
586 prompts = [example["prompts"] for example in examples] 616 prompts = [example["prompts"] for example in examples]
@@ -617,7 +647,6 @@ def main():
617 num_class_images=args.num_class_images, 647 num_class_images=args.num_class_images,
618 size=args.resolution, 648 size=args.resolution,
619 repeats=args.repeats, 649 repeats=args.repeats,
620 mode=args.mode,
621 dropout=args.tag_dropout, 650 dropout=args.tag_dropout,
622 center_crop=args.center_crop, 651 center_crop=args.center_crop,
623 template_key=args.train_data_template, 652 template_key=args.train_data_template,
@@ -769,7 +798,7 @@ def main():
769 target, target_prior = torch.chunk(target, 2, dim=0) 798 target, target_prior = torch.chunk(target, 2, dim=0)
770 799
771 # Compute instance loss 800 # Compute instance loss
772 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() 801 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
773 802
774 # Compute prior loss 803 # Compute prior loss
775 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 804 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
@@ -785,7 +814,7 @@ def main():
785 814
786 if args.find_lr: 815 if args.find_lr:
787 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) 816 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
788 lr_finder.run(min_lr=1e-6, num_train_batches=4) 817 lr_finder.run(min_lr=1e-6, num_train_batches=1)
789 818
790 plt.savefig(basepath.joinpath("lr.png")) 819 plt.savefig(basepath.joinpath("lr.png"))
791 plt.close() 820 plt.close()
@@ -798,6 +827,10 @@ def main():
798 config = vars(args).copy() 827 config = vars(args).copy()
799 config["initializer_token"] = " ".join(config["initializer_token"]) 828 config["initializer_token"] = " ".join(config["initializer_token"])
800 config["placeholder_token"] = " ".join(config["placeholder_token"]) 829 config["placeholder_token"] = " ".join(config["placeholder_token"])
830 if config["exclude_modes"] is not None:
831 config["exclude_modes"] = " ".join(config["exclude_modes"])
832 if config["exclude_keywords"] is not None:
833 config["exclude_keywords"] = " ".join(config["exclude_keywords"])
801 accelerator.init_trackers("textual_inversion", config=config) 834 accelerator.init_trackers("textual_inversion", config=config)
802 835
803 # Train! 836 # Train!
diff --git a/training/lr.py b/training/lr.py
index ef01906..0c5ce9e 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -43,9 +43,6 @@ class LRFinder():
43 ) 43 )
44 progress_bar.set_description("Epoch X / Y") 44 progress_bar.set_description("Epoch X / Y")
45 45
46 train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches]
47 val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches]
48
49 for epoch in range(num_epochs): 46 for epoch in range(num_epochs):
50 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 47 progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
51 48
@@ -54,7 +51,10 @@ class LRFinder():
54 51
55 self.model.train() 52 self.model.train()
56 53
57 for batch in train_workload: 54 for step, batch in enumerate(self.train_dataloader):
55 if step >= num_train_batches:
56 break
57
58 with self.accelerator.accumulate(self.model): 58 with self.accelerator.accumulate(self.model):
59 loss, acc, bsz = self.loss_fn(batch) 59 loss, acc, bsz = self.loss_fn(batch)
60 60
@@ -69,7 +69,10 @@ class LRFinder():
69 self.model.eval() 69 self.model.eval()
70 70
71 with torch.inference_mode(): 71 with torch.inference_mode():
72 for batch in val_workload: 72 for step, batch in enumerate(self.val_dataloader):
73 if step >= num_val_batches:
74 break
75
73 loss, acc, bsz = self.loss_fn(batch) 76 loss, acc, bsz = self.loss_fn(batch)
74 avg_loss.update(loss.detach_(), bsz) 77 avg_loss.update(loss.detach_(), bsz)
75 avg_acc.update(acc.detach_(), bsz) 78 avg_acc.update(acc.detach_(), bsz)