From 19ae465203c8dcc0b1179584db632015362b5e44 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 26 Mar 2023 14:27:54 +0200 Subject: Improved inverted tokens --- data/csv.py | 67 +++++++++++++++++++++++++++++++++----------------- train_dreambooth.py | 1 + train_lora.py | 1 + train_ti.py | 16 +++++++++++- training/functional.py | 19 +++++++------- 5 files changed, 70 insertions(+), 34 deletions(-) diff --git a/data/csv.py b/data/csv.py index d52d251..9770bec 100644 --- a/data/csv.py +++ b/data/csv.py @@ -178,6 +178,7 @@ class VlpnDataModule(): shuffle: bool = False, interpolation: str = "bicubic", template_key: str = "template", + placeholder_tokens: list[str] = [], valid_set_size: Optional[int] = None, train_set_pad: Optional[int] = None, valid_set_pad: Optional[int] = None, @@ -195,6 +196,7 @@ class VlpnDataModule(): self.data_root = self.data_file.parent self.class_root = self.data_root / class_subdir self.class_root.mkdir(parents=True, exist_ok=True) + self.placeholder_tokens = placeholder_tokens self.num_class_images = num_class_images self.with_guidance = with_guidance @@ -217,31 +219,50 @@ class VlpnDataModule(): self.dtype = dtype def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: - image = template["image"] if "image" in template else "{}" - prompt = template["prompt"] if "prompt" in template else "{content}" - cprompt = template["cprompt"] if "cprompt" in template else "{content}" - nprompt = template["nprompt"] if "nprompt" in template else "{content}" + tpl_image = template["image"] if "image" in template else "{}" + tpl_prompt = template["prompt"] if "prompt" in template else "{content}" + tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" + tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" + + items = [] + + for item in data: + image = tpl_image.format(item["image"]) + prompt = item["prompt"] if "prompt" in item else "" + nprompt = item["nprompt"] if "nprompt" in item else "" + collection = item["collection"].split(", ") if "collection" in item else [] + + prompt_keywords = prompt_to_keywords( + tpl_prompt.format(**prepare_prompt(prompt)), + expansions + ) - return [ - VlpnDataItem( - self.data_root / image.format(item["image"]), - None, - prompt_to_keywords( - prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), - expansions - ), - keywords_to_prompt(prompt_to_keywords( - cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), - expansions - )), - prompt_to_keywords( - nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), - expansions - ), - item["collection"].split(", ") if "collection" in item else [] + cprompt = keywords_to_prompt(prompt_to_keywords( + tpl_cprompt.format(**prepare_prompt(prompt)), + expansions + )) + + inverted_tokens = keywords_to_prompt([ + f"inv_{token}" + for token in self.placeholder_tokens + if token in prompt_keywords + ]) + + nprompt_keywords = prompt_to_keywords( + tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)), + expansions ) - for item in data - ] + + items.append(VlpnDataItem( + self.data_root / image, + None, + prompt_keywords, + cprompt, + nprompt_keywords, + collection + )) + + return items def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: if self.filter is None: diff --git a/train_dreambooth.py b/train_dreambooth.py index 7a33bca..9345797 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -585,6 +585,7 @@ def main(): dropout=args.tag_dropout, shuffle=not args.no_tag_shuffle, template_key=args.train_data_template, + placeholder_tokens=args.placeholder_tokens, valid_set_size=args.valid_set_size, train_set_pad=args.train_set_pad, valid_set_pad=args.valid_set_pad, diff --git a/train_lora.py b/train_lora.py index 684d0cc..7ecddf0 100644 --- a/train_lora.py +++ b/train_lora.py @@ -617,6 +617,7 @@ def main(): dropout=args.tag_dropout, shuffle=not args.no_tag_shuffle, template_key=args.train_data_template, + placeholder_tokens=args.placeholder_tokens, valid_set_size=args.valid_set_size, train_set_pad=args.train_set_pad, valid_set_pad=args.valid_set_pad, diff --git a/train_ti.py b/train_ti.py index 83ad46d..6c35d41 100644 --- a/train_ti.py +++ b/train_ti.py @@ -80,6 +80,12 @@ def parse_args(): default=[], help="Tokens to create an alias for." ) + parser.add_argument( + "--inverted_initializer_tokens", + type=str, + nargs='*', + help="A token to use as initializer word." + ) parser.add_argument( "--num_vectors", type=int, @@ -149,7 +155,7 @@ def parse_args(): parser.add_argument( "--num_buckets", type=int, - default=0, + default=2, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( @@ -488,6 +494,13 @@ def parse_args(): if len(args.placeholder_tokens) != len(args.initializer_tokens): raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") + if isinstance(args.inverted_initializer_tokens, str): + args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) + + if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: + args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] + args.initializer_tokens += args.inverted_initializer_tokens + if isinstance(args.num_vectors, int): args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) @@ -720,6 +733,7 @@ def main(): dropout=args.tag_dropout, shuffle=not args.no_tag_shuffle, template_key=data_template, + placeholder_tokens=args.placeholder_tokens, valid_set_size=args.valid_set_size, train_set_pad=args.train_set_pad, valid_set_pad=args.valid_set_pad, diff --git a/training/functional.py b/training/functional.py index 109845b..a2aa24e 100644 --- a/training/functional.py +++ b/training/functional.py @@ -335,14 +335,6 @@ def loss_step( # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - if guidance_scale != 0: uncond_encoder_hidden_states = get_extended_embeddings( text_encoder, @@ -354,8 +346,15 @@ def loss_step( model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - elif prior_loss_weight != 0: + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if guidance_scale == 0 and prior_loss_weight != 0: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) -- cgit v1.2.3-70-g09d2