diff options
| -rw-r--r-- | data/csv.py | 67 | ||||
| -rw-r--r-- | train_dreambooth.py | 1 | ||||
| -rw-r--r-- | train_lora.py | 1 | ||||
| -rw-r--r-- | train_ti.py | 16 | ||||
| -rw-r--r-- | training/functional.py | 19 |
5 files changed, 70 insertions, 34 deletions
diff --git a/data/csv.py b/data/csv.py index d52d251..9770bec 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -178,6 +178,7 @@ class VlpnDataModule(): | |||
| 178 | shuffle: bool = False, | 178 | shuffle: bool = False, |
| 179 | interpolation: str = "bicubic", | 179 | interpolation: str = "bicubic", |
| 180 | template_key: str = "template", | 180 | template_key: str = "template", |
| 181 | placeholder_tokens: list[str] = [], | ||
| 181 | valid_set_size: Optional[int] = None, | 182 | valid_set_size: Optional[int] = None, |
| 182 | train_set_pad: Optional[int] = None, | 183 | train_set_pad: Optional[int] = None, |
| 183 | valid_set_pad: Optional[int] = None, | 184 | valid_set_pad: Optional[int] = None, |
| @@ -195,6 +196,7 @@ class VlpnDataModule(): | |||
| 195 | self.data_root = self.data_file.parent | 196 | self.data_root = self.data_file.parent |
| 196 | self.class_root = self.data_root / class_subdir | 197 | self.class_root = self.data_root / class_subdir |
| 197 | self.class_root.mkdir(parents=True, exist_ok=True) | 198 | self.class_root.mkdir(parents=True, exist_ok=True) |
| 199 | self.placeholder_tokens = placeholder_tokens | ||
| 198 | self.num_class_images = num_class_images | 200 | self.num_class_images = num_class_images |
| 199 | self.with_guidance = with_guidance | 201 | self.with_guidance = with_guidance |
| 200 | 202 | ||
| @@ -217,31 +219,50 @@ class VlpnDataModule(): | |||
| 217 | self.dtype = dtype | 219 | self.dtype = dtype |
| 218 | 220 | ||
| 219 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: | 221 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: |
| 220 | image = template["image"] if "image" in template else "{}" | 222 | tpl_image = template["image"] if "image" in template else "{}" |
| 221 | prompt = template["prompt"] if "prompt" in template else "{content}" | 223 | tpl_prompt = template["prompt"] if "prompt" in template else "{content}" |
| 222 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" | 224 | tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" |
| 223 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 225 | tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
| 224 | 226 | ||
| 225 | return [ | 227 | items = [] |
| 226 | VlpnDataItem( | 228 | |
| 227 | self.data_root / image.format(item["image"]), | 229 | for item in data: |
| 228 | None, | 230 | image = tpl_image.format(item["image"]) |
| 229 | prompt_to_keywords( | 231 | prompt = item["prompt"] if "prompt" in item else "" |
| 230 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 232 | nprompt = item["nprompt"] if "nprompt" in item else "" |
| 231 | expansions | 233 | collection = item["collection"].split(", ") if "collection" in item else [] |
| 232 | ), | 234 | |
| 233 | keywords_to_prompt(prompt_to_keywords( | 235 | prompt_keywords = prompt_to_keywords( |
| 234 | cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 236 | tpl_prompt.format(**prepare_prompt(prompt)), |
| 235 | expansions | 237 | expansions |
| 236 | )), | ||
| 237 | prompt_to_keywords( | ||
| 238 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | ||
| 239 | expansions | ||
| 240 | ), | ||
| 241 | item["collection"].split(", ") if "collection" in item else [] | ||
| 242 | ) | 238 | ) |
| 243 | for item in data | 239 | |
| 244 | ] | 240 | cprompt = keywords_to_prompt(prompt_to_keywords( |
| 241 | tpl_cprompt.format(**prepare_prompt(prompt)), | ||
| 242 | expansions | ||
| 243 | )) | ||
| 244 | |||
| 245 | inverted_tokens = keywords_to_prompt([ | ||
| 246 | f"inv_{token}" | ||
| 247 | for token in self.placeholder_tokens | ||
| 248 | if token in prompt_keywords | ||
| 249 | ]) | ||
| 250 | |||
| 251 | nprompt_keywords = prompt_to_keywords( | ||
| 252 | tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)), | ||
| 253 | expansions | ||
| 254 | ) | ||
| 255 | |||
| 256 | items.append(VlpnDataItem( | ||
| 257 | self.data_root / image, | ||
| 258 | None, | ||
| 259 | prompt_keywords, | ||
| 260 | cprompt, | ||
| 261 | nprompt_keywords, | ||
| 262 | collection | ||
| 263 | )) | ||
| 264 | |||
| 265 | return items | ||
| 245 | 266 | ||
| 246 | def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: | 267 | def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: |
| 247 | if self.filter is None: | 268 | if self.filter is None: |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 7a33bca..9345797 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -585,6 +585,7 @@ def main(): | |||
| 585 | dropout=args.tag_dropout, | 585 | dropout=args.tag_dropout, |
| 586 | shuffle=not args.no_tag_shuffle, | 586 | shuffle=not args.no_tag_shuffle, |
| 587 | template_key=args.train_data_template, | 587 | template_key=args.train_data_template, |
| 588 | placeholder_tokens=args.placeholder_tokens, | ||
| 588 | valid_set_size=args.valid_set_size, | 589 | valid_set_size=args.valid_set_size, |
| 589 | train_set_pad=args.train_set_pad, | 590 | train_set_pad=args.train_set_pad, |
| 590 | valid_set_pad=args.valid_set_pad, | 591 | valid_set_pad=args.valid_set_pad, |
diff --git a/train_lora.py b/train_lora.py index 684d0cc..7ecddf0 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -617,6 +617,7 @@ def main(): | |||
| 617 | dropout=args.tag_dropout, | 617 | dropout=args.tag_dropout, |
| 618 | shuffle=not args.no_tag_shuffle, | 618 | shuffle=not args.no_tag_shuffle, |
| 619 | template_key=args.train_data_template, | 619 | template_key=args.train_data_template, |
| 620 | placeholder_tokens=args.placeholder_tokens, | ||
| 620 | valid_set_size=args.valid_set_size, | 621 | valid_set_size=args.valid_set_size, |
| 621 | train_set_pad=args.train_set_pad, | 622 | train_set_pad=args.train_set_pad, |
| 622 | valid_set_pad=args.valid_set_pad, | 623 | valid_set_pad=args.valid_set_pad, |
diff --git a/train_ti.py b/train_ti.py index 83ad46d..6c35d41 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -81,6 +81,12 @@ def parse_args(): | |||
| 81 | help="Tokens to create an alias for." | 81 | help="Tokens to create an alias for." |
| 82 | ) | 82 | ) |
| 83 | parser.add_argument( | 83 | parser.add_argument( |
| 84 | "--inverted_initializer_tokens", | ||
| 85 | type=str, | ||
| 86 | nargs='*', | ||
| 87 | help="A token to use as initializer word." | ||
| 88 | ) | ||
| 89 | parser.add_argument( | ||
| 84 | "--num_vectors", | 90 | "--num_vectors", |
| 85 | type=int, | 91 | type=int, |
| 86 | nargs='*', | 92 | nargs='*', |
| @@ -149,7 +155,7 @@ def parse_args(): | |||
| 149 | parser.add_argument( | 155 | parser.add_argument( |
| 150 | "--num_buckets", | 156 | "--num_buckets", |
| 151 | type=int, | 157 | type=int, |
| 152 | default=0, | 158 | default=2, |
| 153 | help="Number of aspect ratio buckets in either direction.", | 159 | help="Number of aspect ratio buckets in either direction.", |
| 154 | ) | 160 | ) |
| 155 | parser.add_argument( | 161 | parser.add_argument( |
| @@ -488,6 +494,13 @@ def parse_args(): | |||
| 488 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 494 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
| 489 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 495 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") |
| 490 | 496 | ||
| 497 | if isinstance(args.inverted_initializer_tokens, str): | ||
| 498 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | ||
| 499 | |||
| 500 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | ||
| 501 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
| 502 | args.initializer_tokens += args.inverted_initializer_tokens | ||
| 503 | |||
| 491 | if isinstance(args.num_vectors, int): | 504 | if isinstance(args.num_vectors, int): |
| 492 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 505 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 493 | 506 | ||
| @@ -720,6 +733,7 @@ def main(): | |||
| 720 | dropout=args.tag_dropout, | 733 | dropout=args.tag_dropout, |
| 721 | shuffle=not args.no_tag_shuffle, | 734 | shuffle=not args.no_tag_shuffle, |
| 722 | template_key=data_template, | 735 | template_key=data_template, |
| 736 | placeholder_tokens=args.placeholder_tokens, | ||
| 723 | valid_set_size=args.valid_set_size, | 737 | valid_set_size=args.valid_set_size, |
| 724 | train_set_pad=args.train_set_pad, | 738 | train_set_pad=args.train_set_pad, |
| 725 | valid_set_pad=args.valid_set_pad, | 739 | valid_set_pad=args.valid_set_pad, |
diff --git a/training/functional.py b/training/functional.py index 109845b..a2aa24e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -335,14 +335,6 @@ def loss_step( | |||
| 335 | # Predict the noise residual | 335 | # Predict the noise residual |
| 336 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 336 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 337 | 337 | ||
| 338 | # Get the target for loss depending on the prediction type | ||
| 339 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 340 | target = noise | ||
| 341 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 342 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 343 | else: | ||
| 344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 345 | |||
| 346 | if guidance_scale != 0: | 338 | if guidance_scale != 0: |
| 347 | uncond_encoder_hidden_states = get_extended_embeddings( | 339 | uncond_encoder_hidden_states = get_extended_embeddings( |
| 348 | text_encoder, | 340 | text_encoder, |
| @@ -354,8 +346,15 @@ def loss_step( | |||
| 354 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample | 346 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample |
| 355 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 347 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) |
| 356 | 348 | ||
| 357 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 349 | # Get the target for loss depending on the prediction type |
| 358 | elif prior_loss_weight != 0: | 350 | if noise_scheduler.config.prediction_type == "epsilon": |
| 351 | target = noise | ||
| 352 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 353 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 354 | else: | ||
| 355 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 356 | |||
| 357 | if guidance_scale == 0 and prior_loss_weight != 0: | ||
| 359 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 358 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 360 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 359 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| 361 | target, target_prior = torch.chunk(target, 2, dim=0) | 360 | target, target_prior = torch.chunk(target, 2, dim=0) |
