summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py67
-rw-r--r--train_dreambooth.py1
-rw-r--r--train_lora.py1
-rw-r--r--train_ti.py16
-rw-r--r--training/functional.py19
5 files changed, 70 insertions, 34 deletions
diff --git a/data/csv.py b/data/csv.py
index d52d251..9770bec 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -178,6 +178,7 @@ class VlpnDataModule():
178 shuffle: bool = False, 178 shuffle: bool = False,
179 interpolation: str = "bicubic", 179 interpolation: str = "bicubic",
180 template_key: str = "template", 180 template_key: str = "template",
181 placeholder_tokens: list[str] = [],
181 valid_set_size: Optional[int] = None, 182 valid_set_size: Optional[int] = None,
182 train_set_pad: Optional[int] = None, 183 train_set_pad: Optional[int] = None,
183 valid_set_pad: Optional[int] = None, 184 valid_set_pad: Optional[int] = None,
@@ -195,6 +196,7 @@ class VlpnDataModule():
195 self.data_root = self.data_file.parent 196 self.data_root = self.data_file.parent
196 self.class_root = self.data_root / class_subdir 197 self.class_root = self.data_root / class_subdir
197 self.class_root.mkdir(parents=True, exist_ok=True) 198 self.class_root.mkdir(parents=True, exist_ok=True)
199 self.placeholder_tokens = placeholder_tokens
198 self.num_class_images = num_class_images 200 self.num_class_images = num_class_images
199 self.with_guidance = with_guidance 201 self.with_guidance = with_guidance
200 202
@@ -217,31 +219,50 @@ class VlpnDataModule():
217 self.dtype = dtype 219 self.dtype = dtype
218 220
219 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: 221 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
220 image = template["image"] if "image" in template else "{}" 222 tpl_image = template["image"] if "image" in template else "{}"
221 prompt = template["prompt"] if "prompt" in template else "{content}" 223 tpl_prompt = template["prompt"] if "prompt" in template else "{content}"
222 cprompt = template["cprompt"] if "cprompt" in template else "{content}" 224 tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}"
223 nprompt = template["nprompt"] if "nprompt" in template else "{content}" 225 tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}"
226
227 items = []
228
229 for item in data:
230 image = tpl_image.format(item["image"])
231 prompt = item["prompt"] if "prompt" in item else ""
232 nprompt = item["nprompt"] if "nprompt" in item else ""
233 collection = item["collection"].split(", ") if "collection" in item else []
234
235 prompt_keywords = prompt_to_keywords(
236 tpl_prompt.format(**prepare_prompt(prompt)),
237 expansions
238 )
224 239
225 return [ 240 cprompt = keywords_to_prompt(prompt_to_keywords(
226 VlpnDataItem( 241 tpl_cprompt.format(**prepare_prompt(prompt)),
227 self.data_root / image.format(item["image"]), 242 expansions
228 None, 243 ))
229 prompt_to_keywords( 244
230 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 245 inverted_tokens = keywords_to_prompt([
231 expansions 246 f"inv_{token}"
232 ), 247 for token in self.placeholder_tokens
233 keywords_to_prompt(prompt_to_keywords( 248 if token in prompt_keywords
234 cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 249 ])
235 expansions 250
236 )), 251 nprompt_keywords = prompt_to_keywords(
237 prompt_to_keywords( 252 tpl_nprompt.format(_inv=inverted_tokens, **prepare_prompt(nprompt)),
238 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 253 expansions
239 expansions
240 ),
241 item["collection"].split(", ") if "collection" in item else []
242 ) 254 )
243 for item in data 255
244 ] 256 items.append(VlpnDataItem(
257 self.data_root / image,
258 None,
259 prompt_keywords,
260 cprompt,
261 nprompt_keywords,
262 collection
263 ))
264
265 return items
245 266
246 def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: 267 def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]:
247 if self.filter is None: 268 if self.filter is None:
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 7a33bca..9345797 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -585,6 +585,7 @@ def main():
585 dropout=args.tag_dropout, 585 dropout=args.tag_dropout,
586 shuffle=not args.no_tag_shuffle, 586 shuffle=not args.no_tag_shuffle,
587 template_key=args.train_data_template, 587 template_key=args.train_data_template,
588 placeholder_tokens=args.placeholder_tokens,
588 valid_set_size=args.valid_set_size, 589 valid_set_size=args.valid_set_size,
589 train_set_pad=args.train_set_pad, 590 train_set_pad=args.train_set_pad,
590 valid_set_pad=args.valid_set_pad, 591 valid_set_pad=args.valid_set_pad,
diff --git a/train_lora.py b/train_lora.py
index 684d0cc..7ecddf0 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -617,6 +617,7 @@ def main():
617 dropout=args.tag_dropout, 617 dropout=args.tag_dropout,
618 shuffle=not args.no_tag_shuffle, 618 shuffle=not args.no_tag_shuffle,
619 template_key=args.train_data_template, 619 template_key=args.train_data_template,
620 placeholder_tokens=args.placeholder_tokens,
620 valid_set_size=args.valid_set_size, 621 valid_set_size=args.valid_set_size,
621 train_set_pad=args.train_set_pad, 622 train_set_pad=args.train_set_pad,
622 valid_set_pad=args.valid_set_pad, 623 valid_set_pad=args.valid_set_pad,
diff --git a/train_ti.py b/train_ti.py
index 83ad46d..6c35d41 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -81,6 +81,12 @@ def parse_args():
81 help="Tokens to create an alias for." 81 help="Tokens to create an alias for."
82 ) 82 )
83 parser.add_argument( 83 parser.add_argument(
84 "--inverted_initializer_tokens",
85 type=str,
86 nargs='*',
87 help="A token to use as initializer word."
88 )
89 parser.add_argument(
84 "--num_vectors", 90 "--num_vectors",
85 type=int, 91 type=int,
86 nargs='*', 92 nargs='*',
@@ -149,7 +155,7 @@ def parse_args():
149 parser.add_argument( 155 parser.add_argument(
150 "--num_buckets", 156 "--num_buckets",
151 type=int, 157 type=int,
152 default=0, 158 default=2,
153 help="Number of aspect ratio buckets in either direction.", 159 help="Number of aspect ratio buckets in either direction.",
154 ) 160 )
155 parser.add_argument( 161 parser.add_argument(
@@ -488,6 +494,13 @@ def parse_args():
488 if len(args.placeholder_tokens) != len(args.initializer_tokens): 494 if len(args.placeholder_tokens) != len(args.initializer_tokens):
489 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") 495 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
490 496
497 if isinstance(args.inverted_initializer_tokens, str):
498 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens)
499
500 if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0:
501 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
502 args.initializer_tokens += args.inverted_initializer_tokens
503
491 if isinstance(args.num_vectors, int): 504 if isinstance(args.num_vectors, int):
492 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) 505 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
493 506
@@ -720,6 +733,7 @@ def main():
720 dropout=args.tag_dropout, 733 dropout=args.tag_dropout,
721 shuffle=not args.no_tag_shuffle, 734 shuffle=not args.no_tag_shuffle,
722 template_key=data_template, 735 template_key=data_template,
736 placeholder_tokens=args.placeholder_tokens,
723 valid_set_size=args.valid_set_size, 737 valid_set_size=args.valid_set_size,
724 train_set_pad=args.train_set_pad, 738 train_set_pad=args.train_set_pad,
725 valid_set_pad=args.valid_set_pad, 739 valid_set_pad=args.valid_set_pad,
diff --git a/training/functional.py b/training/functional.py
index 109845b..a2aa24e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -335,14 +335,6 @@ def loss_step(
335 # Predict the noise residual 335 # Predict the noise residual
336 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 336 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
337 337
338 # Get the target for loss depending on the prediction type
339 if noise_scheduler.config.prediction_type == "epsilon":
340 target = noise
341 elif noise_scheduler.config.prediction_type == "v_prediction":
342 target = noise_scheduler.get_velocity(latents, noise, timesteps)
343 else:
344 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
345
346 if guidance_scale != 0: 338 if guidance_scale != 0:
347 uncond_encoder_hidden_states = get_extended_embeddings( 339 uncond_encoder_hidden_states = get_extended_embeddings(
348 text_encoder, 340 text_encoder,
@@ -354,8 +346,15 @@ def loss_step(
354 model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample 346 model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
355 model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) 347 model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
356 348
357 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 349 # Get the target for loss depending on the prediction type
358 elif prior_loss_weight != 0: 350 if noise_scheduler.config.prediction_type == "epsilon":
351 target = noise
352 elif noise_scheduler.config.prediction_type == "v_prediction":
353 target = noise_scheduler.get_velocity(latents, noise, timesteps)
354 else:
355 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
356
357 if guidance_scale == 0 and prior_loss_weight != 0:
359 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 358 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
360 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 359 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
361 target, target_prior = torch.chunk(target, 2, dim=0) 360 target, target_prior = torch.chunk(target, 2, dim=0)