diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 41 |
1 file changed, 24 insertions(+), 17 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py index 0d5a742..69d9c7f 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -55,9 +55,9 @@ def parse_args(): | |||
55 | help="A CSV file containing the training data." | 55 | help="A CSV file containing the training data." |
56 | ) | 56 | ) |
57 | parser.add_argument( | 57 | parser.add_argument( |
58 | "--placeholder_token", | 58 | "--instance_identifier", |
59 | type=str, | 59 | type=str, |
60 | default="<*>", | 60 | default=None, |
61 | help="A token to use as a placeholder for the concept.", | 61 | help="A name identifying the trained instance; '{}' is replaced by the placeholder token.", |
62 | ) | 62 | ) |
63 | parser.add_argument( | 63 | parser.add_argument( |
@@ -67,6 +67,12 @@ def parse_args(): | |||
67 | help="A token to use as a placeholder for the concept.", | 67 | help="A token to use as a placeholder for the concept.", |
68 | ) | 68 | ) |
69 | parser.add_argument( | 69 | parser.add_argument( |
70 | "--placeholder_token", | ||
71 | type=str, | ||
72 | default="<*>", | ||
73 | help="A token to use as a placeholder for the concept.", | ||
74 | ) | ||
75 | parser.add_argument( | ||
70 | "--initializer_token", | 76 | "--initializer_token", |
71 | type=str, | 77 | type=str, |
72 | default=None, | 78 | default=None, |
@@ -333,6 +339,7 @@ class Checkpointer: | |||
333 | unet, | 339 | unet, |
334 | tokenizer, | 340 | tokenizer, |
335 | text_encoder, | 341 | text_encoder, |
342 | instance_identifier, | ||
336 | placeholder_token, | 343 | placeholder_token, |
337 | placeholder_token_id, | 344 | placeholder_token_id, |
338 | output_dir: Path, | 345 | output_dir: Path, |
@@ -347,6 +354,7 @@ class Checkpointer: | |||
347 | self.unet = unet | 354 | self.unet = unet |
348 | self.tokenizer = tokenizer | 355 | self.tokenizer = tokenizer |
349 | self.text_encoder = text_encoder | 356 | self.text_encoder = text_encoder |
357 | self.instance_identifier = instance_identifier | ||
350 | self.placeholder_token = placeholder_token | 358 | self.placeholder_token = placeholder_token |
351 | self.placeholder_token_id = placeholder_token_id | 359 | self.placeholder_token_id = placeholder_token_id |
352 | self.output_dir = output_dir | 360 | self.output_dir = output_dir |
@@ -413,7 +421,7 @@ class Checkpointer: | |||
413 | 421 | ||
414 | for i in range(self.sample_batches): | 422 | for i in range(self.sample_batches): |
415 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 423 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
416 | prompt = [prompt.format(self.placeholder_token) | 424 | prompt = [prompt.format(self.instance_identifier) |
417 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | 425 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
418 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | 426 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] |
419 | 427 | ||
@@ -428,7 +436,7 @@ class Checkpointer: | |||
428 | eta=eta, | 436 | eta=eta, |
429 | num_inference_steps=num_inference_steps, | 437 | num_inference_steps=num_inference_steps, |
430 | output_type='pil' | 438 | output_type='pil' |
431 | )["sample"] | 439 | ).images |
432 | 440 | ||
433 | all_samples += samples | 441 | all_samples += samples |
434 | 442 | ||
@@ -480,28 +488,26 @@ def main(): | |||
480 | if args.seed is not None: | 488 | if args.seed is not None: |
481 | set_seed(args.seed) | 489 | set_seed(args.seed) |
482 | 490 | ||
491 | args.instance_identifier = args.instance_identifier.format(args.placeholder_token) | ||
492 | |||
483 | # Load the tokenizer and add the placeholder token as an additional special token | 493 | # Load the tokenizer and add the placeholder token as an additional special token |
484 | if args.tokenizer_name: | 494 | if args.tokenizer_name: |
485 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) | 495 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) |
486 | elif args.pretrained_model_name_or_path: | 496 | elif args.pretrained_model_name_or_path: |
487 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') | 497 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') |
488 | 498 | ||
499 | # Convert the initializer_token, placeholder_token to ids | ||
500 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
501 | print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.") | ||
502 | initializer_token_ids = torch.tensor(initializer_token_ids[:1]) | ||
503 | |||
489 | # Add the placeholder token in tokenizer | 504 | # Add the placeholder token in tokenizer |
490 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) | 505 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) |
491 | if num_added_tokens == 0: | 506 | if num_added_tokens == 0: |
492 | raise ValueError( | 507 | print(f"Re-using existing token {args.placeholder_token}.") |
493 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" | 508 | else: |
494 | " `placeholder_token` that is not already in the tokenizer." | 509 | print(f"Training new token {args.placeholder_token}.") |
495 | ) | ||
496 | |||
497 | # Convert the initializer_token, placeholder_token to ids | ||
498 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
499 | # Check if initializer_token is a single token or a sequence of tokens | ||
500 | if len(initializer_token_ids) > 1: | ||
501 | raise ValueError( | ||
502 | f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.") | ||
503 | 510 | ||
504 | initializer_token_ids = torch.tensor(initializer_token_ids) | ||
505 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 511 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
506 | 512 | ||
507 | # Load models and create wrapper for stable diffusion | 513 | # Load models and create wrapper for stable diffusion |
@@ -602,7 +608,7 @@ def main(): | |||
602 | data_file=args.train_data_file, | 608 | data_file=args.train_data_file, |
603 | batch_size=args.train_batch_size, | 609 | batch_size=args.train_batch_size, |
604 | tokenizer=tokenizer, | 610 | tokenizer=tokenizer, |
605 | instance_identifier=args.placeholder_token, | 611 | instance_identifier=args.instance_identifier, |
606 | class_identifier=args.class_identifier, | 612 | class_identifier=args.class_identifier, |
607 | class_subdir="cls", | 613 | class_subdir="cls", |
608 | num_class_images=args.num_class_images, | 614 | num_class_images=args.num_class_images, |
@@ -730,6 +736,7 @@ def main(): | |||
730 | unet=unet, | 736 | unet=unet, |
731 | tokenizer=tokenizer, | 737 | tokenizer=tokenizer, |
732 | text_encoder=text_encoder, | 738 | text_encoder=text_encoder, |
739 | instance_identifier=args.instance_identifier, | ||
733 | placeholder_token=args.placeholder_token, | 740 | placeholder_token=args.placeholder_token, |
734 | placeholder_token_id=placeholder_token_id, | 741 | placeholder_token_id=placeholder_token_id, |
735 | output_dir=basepath, | 742 | output_dir=basepath, |