path: root/textual_inversion.py
Diffstat (limited to 'textual_inversion.py')
 textual_inversion.py | 41 ++++++++++++++++++++++++-----------------
 1 file changed, 24 insertions(+), 17 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 0d5a742..69d9c7f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -55,9 +55,9 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--instance_identifier",
         type=str,
-        default="<*>",
+        default=None,
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
@@ -67,6 +67,12 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
         "--initializer_token",
         type=str,
         default=None,
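The new --instance_identifier flag works together with --placeholder_token: further down, main() resolves it via args.instance_identifier.format(args.placeholder_token), so the identifier may carry a "{}" slot that receives the placeholder token. A minimal sketch of that interaction (the flag values here are hypothetical):

    # Hypothetical flag values; only the .format() call mirrors the script.
    placeholder_token = "<*>"              # --placeholder_token
    instance_identifier = "photo of {}"    # --instance_identifier
    instance_identifier = instance_identifier.format(placeholder_token)
    print(instance_identifier)             # photo of <*>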
@@ -333,6 +339,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
+        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -347,6 +354,7 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -413,7 +421,7 @@ class Checkpointer:

         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(self.placeholder_token)
+            prompt = [prompt.format(self.instance_identifier)
                       for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
             nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

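The sampling prompts are format templates: each carries a "{}" slot that previously received the raw placeholder token and now receives the resolved instance identifier. A small sketch with made-up prompt data:

    # Made-up prompt templates; real ones come from batch["prompts"].
    prompts = ["a {} on a beach", "a portrait of {}"]
    instance_identifier = "<*>"  # whatever self.instance_identifier holds
    print([p.format(instance_identifier) for p in prompts])
    # ['a <*> on a beach', 'a portrait of <*>']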
@@ -428,7 +436,7 @@ class Checkpointer:
                 eta=eta,
                 num_inference_steps=num_inference_steps,
                 output_type='pil'
-            )["sample"]
+            ).images

             all_samples += samples

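This tracks an upstream diffusers API change: pipelines stopped returning a dict keyed by "sample" and now return an output object whose .images attribute holds the generated PIL images. A sketch of the call shape, assuming a StableDiffusionPipeline-style callable:

    # Sketch: newer diffusers pipelines return an output object, not a dict.
    result = pipeline(prompt, negative_prompt=nprompt, output_type='pil')
    samples = result.images  # list of PIL.Image.Image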
@@ -480,28 +488,26 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)

+    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-            " `placeholder_token` that is not already in the tokenizer."
-        )
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # Check if initializer_token is a single token or a sequence of tokens
-    if len(initializer_token_ids) > 1:
-        raise ValueError(
-            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+        print(f"Re-using existing token {args.placeholder_token}.")
+    else:
+        print(f"Training new token {args.placeholder_token}.")

-    initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
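Two behavior changes land in this hunk: a multi-token initializer is no longer an error but is silently truncated to its first token id (note the [:1] slice), and a placeholder token that already exists in the tokenizer is re-used instead of rejected. A minimal sketch of the truncation, assuming a CLIP tokenizer (the example strings are hypothetical):

    import torch
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # A multi-word initializer encodes to several ids...
    ids = tokenizer.encode("golden retriever", add_special_tokens=False)
    print(f"Initializer token maps to {len(ids)} embeddings.")  # > 1

    # ...and only the first id is kept as the initializer embedding:
    initializer_token_ids = torch.tensor(ids[:1])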
@@ -602,7 +608,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_identifier=args.placeholder_token,
+        instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -730,6 +736,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,