From dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 30 Dec 2022 13:48:26 +0100 Subject: Training script improvements --- data/csv.py | 15 ++++----------- train_dreambooth.py | 41 ++++++++++++++++++++++++++++++++++++++--- train_lora.py | 2 +- train_ti.py | 43 ++++++++++++++++++++++++++++++++++++++----- training/lr.py | 13 ++++++++----- 5 files changed, 89 insertions(+), 25 deletions(-) diff --git a/data/csv.py b/data/csv.py index 0ad36dc..4da5d64 100644 --- a/data/csv.py +++ b/data/csv.py @@ -41,6 +41,7 @@ class CSVDataItem(NamedTuple): prompt: list[str] cprompt: str nprompt: str + mode: list[str] class CSVDataModule(): @@ -56,7 +57,6 @@ class CSVDataModule(): dropout: float = 0, interpolation: str = "bicubic", center_crop: bool = False, - mode: Optional[str] = None, template_key: str = "template", valid_set_size: Optional[int] = None, generator: Optional[torch.Generator] = None, @@ -81,7 +81,6 @@ class CSVDataModule(): self.repeats = repeats self.dropout = dropout self.center_crop = center_crop - self.mode = mode self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -113,6 +112,7 @@ class CSVDataModule(): nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), expansions )), + item["mode"].split(", ") if "mode" in item else [] ) for item in data ] @@ -133,6 +133,7 @@ class CSVDataModule(): item.prompt, item.cprompt, item.nprompt, + item.mode, ) for item in items for i in range(image_multiplier) @@ -145,20 +146,12 @@ class CSVDataModule(): expansions = metadata["expansions"] if "expansions" in metadata else {} items = metadata["items"] if "items" in metadata else [] - if self.mode is not None: - items = [ - item - for item in items - if "mode" in item and self.mode in item["mode"].split(", ") - ] items = self.prepare_items(template, expansions, items) items = self.filter_items(items) num_images = len(items) - valid_set_size = int(num_images * 0.1) - if self.valid_set_size: - valid_set_size = min(valid_set_size, self.valid_set_size) + valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1) valid_set_size = max(valid_set_size, 1) train_set_size = num_images - valid_set_size diff --git a/train_dreambooth.py b/train_dreambooth.py index 202d52c..072150b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -22,7 +22,7 @@ from slugify import slugify from common import load_text_embeddings, load_config from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from data.csv import CSVDataModule +from data.csv import CSVDataModule, CSVDataItem from training.optimization import get_one_cycle_schedule from training.ti import patch_trainable_embeddings from training.util import AverageMeter, CheckpointerBase, save_args @@ -82,6 +82,18 @@ def parse_args(): default=[], help="A token to use as initializer word." ) + parser.add_argument( + "--exclude_keywords", + type=str, + nargs='*', + help="Skip dataset items containing a listed keyword.", + ) + parser.add_argument( + "--exclude_modes", + type=str, + nargs='*', + help="Exclude all items with a listed mode.", + ) parser.add_argument( "--train_text_encoder", action="store_true", @@ -379,6 +391,12 @@ def parse_args(): if len(args.placeholder_token) != len(args.initializer_token): raise ValueError("Number of items in --placeholder_token and --initializer_token must match") + if isinstance(args.exclude_keywords, str): + args.exclude_keywords = [args.exclude_keywords] + + if isinstance(args.exclude_modes, str): + args.exclude_modes = [args.exclude_modes] + if args.output_dir is None: raise ValueError("You must specify --output_dir") @@ -636,6 +654,19 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + def keyword_filter(item: CSVDataItem): + cond2 = args.exclude_keywords is None or not any( + keyword in part + for keyword in args.exclude_keywords + for part in item.prompt + ) + cond3 = args.mode is None or args.mode in item.mode + cond4 = args.exclude_modes is None or not any( + mode in item.mode + for mode in args.exclude_modes + ) + return cond2 and cond3 and cond4 + def collate_fn(examples): prompts = [example["prompts"] for example in examples] cprompts = [example["cprompts"] for example in examples] @@ -671,12 +702,12 @@ def main(): num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, - mode=args.mode, dropout=args.tag_dropout, center_crop=args.center_crop, template_key=args.train_data_template, valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, + filter=keyword_filter, collate_fn=collate_fn ) @@ -782,6 +813,10 @@ def main(): config = vars(args).copy() config["initializer_token"] = " ".join(config["initializer_token"]) config["placeholder_token"] = " ".join(config["placeholder_token"]) + if config["exclude_modes"] is not None: + config["exclude_modes"] = " ".join(config["exclude_modes"]) + if config["exclude_keywords"] is not None: + config["exclude_keywords"] = " ".join(config["exclude_keywords"]) accelerator.init_trackers("dreambooth", config=config) # Train! @@ -879,7 +914,7 @@ def main(): target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") diff --git a/train_lora.py b/train_lora.py index 9a42cae..de878a4 100644 --- a/train_lora.py +++ b/train_lora.py @@ -810,7 +810,7 @@ def main(): target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") diff --git a/train_ti.py b/train_ti.py index b1f6a49..6aa4007 100644 --- a/train_ti.py +++ b/train_ti.py @@ -92,6 +92,18 @@ def parse_args(): default="cls", help="The directory where class images will be saved.", ) + parser.add_argument( + "--exclude_keywords", + type=str, + nargs='*', + help="Skip dataset items containing a listed keyword.", + ) + parser.add_argument( + "--exclude_modes", + type=str, + nargs='*', + help="Exclude all items with a listed mode.", + ) parser.add_argument( "--repeats", type=int, @@ -120,7 +132,8 @@ def parse_args(): "--seed", type=int, default=None, - help="A seed for reproducible training.") + help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -356,6 +369,12 @@ def parse_args(): if len(args.placeholder_token) != len(args.initializer_token): raise ValueError("You must specify --placeholder_token") + if isinstance(args.exclude_keywords, str): + args.exclude_keywords = [args.exclude_keywords] + + if isinstance(args.exclude_modes, str): + args.exclude_modes = [args.exclude_modes] + if args.output_dir is None: raise ValueError("You must specify --output_dir") @@ -576,11 +595,22 @@ def main(): weight_dtype = torch.bfloat16 def keyword_filter(item: CSVDataItem): - return any( + cond1 = any( keyword in part for keyword in args.placeholder_token for part in item.prompt ) + cond2 = args.exclude_keywords is None or not any( + keyword in part + for keyword in args.exclude_keywords + for part in item.prompt + ) + cond3 = args.mode is None or args.mode in item.mode + cond4 = args.exclude_modes is None or not any( + mode in item.mode + for mode in args.exclude_modes + ) + return cond1 and cond2 and cond3 and cond4 def collate_fn(examples): prompts = [example["prompts"] for example in examples] @@ -617,7 +647,6 @@ def main(): num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, - mode=args.mode, dropout=args.tag_dropout, center_crop=args.center_crop, template_key=args.train_data_template, @@ -769,7 +798,7 @@ def main(): target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") @@ -785,7 +814,7 @@ def main(): if args.find_lr: lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) - lr_finder.run(min_lr=1e-6, num_train_batches=4) + lr_finder.run(min_lr=1e-6, num_train_batches=1) plt.savefig(basepath.joinpath("lr.png")) plt.close() @@ -798,6 +827,10 @@ def main(): config = vars(args).copy() config["initializer_token"] = " ".join(config["initializer_token"]) config["placeholder_token"] = " ".join(config["placeholder_token"]) + if config["exclude_modes"] is not None: + config["exclude_modes"] = " ".join(config["exclude_modes"]) + if config["exclude_keywords"] is not None: + config["exclude_keywords"] = " ".join(config["exclude_keywords"]) accelerator.init_trackers("textual_inversion", config=config) # Train! diff --git a/training/lr.py b/training/lr.py index ef01906..0c5ce9e 100644 --- a/training/lr.py +++ b/training/lr.py @@ -43,9 +43,6 @@ class LRFinder(): ) progress_bar.set_description("Epoch X / Y") - train_workload = [batch for i, batch in enumerate(self.train_dataloader) if i < num_train_batches] - val_workload = [batch for i, batch in enumerate(self.val_dataloader) if i < num_val_batches] - for epoch in range(num_epochs): progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -54,7 +51,10 @@ class LRFinder(): self.model.train() - for batch in train_workload: + for step, batch in enumerate(self.train_dataloader): + if step >= num_train_batches: + break + with self.accelerator.accumulate(self.model): loss, acc, bsz = self.loss_fn(batch) @@ -69,7 +69,10 @@ class LRFinder(): self.model.eval() with torch.inference_mode(): - for batch in val_workload: + for step, batch in enumerate(self.val_dataloader): + if step >= num_val_batches: + break + loss, acc, bsz = self.loss_fn(batch) avg_loss.update(loss.detach_(), bsz) avg_acc.update(acc.detach_(), bsz) -- cgit v1.2.3-70-g09d2