Diffstat (limited to 'textual_inversion.py')
 textual_inversion.py | 41 ++++++++++++++++++++++++-----------------
 1 file changed, 24 insertions(+), 17 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 0d5a742..69d9c7f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -55,9 +55,9 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--instance_identifier",
         type=str,
-        default="<*>",
+        default=None,
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
@@ -67,6 +67,12 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
         "--initializer_token",
         type=str,
         default=None,
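
Note: this hunk splits the old argument in two. --placeholder_token (default "<*>")
remains the token whose embedding is trained, while the new --instance_identifier
(default None, so a value apparently has to be supplied) controls how the concept is
written into prompts; main() later expands it with str.format. A minimal sketch of
that expansion, with illustrative values:

    placeholder_token = "<*>"        # --placeholder_token
    instance_identifier = "{}"       # an illustrative --instance_identifier value
    instance_identifier = instance_identifier.format(placeholder_token)
    print(instance_identifier)       # -> <*>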
@@ -333,6 +339,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
+        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -347,6 +354,7 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -413,7 +421,7 @@ class Checkpointer:
 
             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.placeholder_token)
+                prompt = [prompt.format(self.instance_identifier)
                           for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
                 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
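
Note: sample prompts arrive as templates containing "{}", which were previously
filled with the bare placeholder token and now receive the full instance
identifier. A small illustration (the template and identifier values here are
hypothetical):

    instance_identifier = "photo of <*>"           # illustrative
    template = "a {} on the beach"                 # a batch["prompts"] entry
    print(template.format(instance_identifier))    # -> a photo of <*> on the beach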
@@ -428,7 +436,7 @@ class Checkpointer:
                     eta=eta,
                     num_inference_steps=num_inference_steps,
                     output_type='pil'
-                )["sample"]
+                ).images
 
                 all_samples += samples
 
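
Note: newer diffusers releases return a pipeline output object instead of a plain
dict, so the generated PIL images are read from the .images attribute rather than
["sample"]. A minimal sketch, assuming the stock StableDiffusionPipeline API (the
model id is illustrative):

    from diffusers import StableDiffusionPipeline

    pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    result = pipeline("a photo of <*>", num_inference_steps=20, output_type="pil")
    images = result.images       # current API: list of PIL images
    # images = result["sample"]  # older releases returned a dict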
@@ -480,28 +488,26 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)
 
+    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-            " `placeholder_token` that is not already in the tokenizer."
-        )
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # Check if initializer_token is a single token or a sequence of tokens
-    if len(initializer_token_ids) > 1:
-        raise ValueError(
-            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+        print(f"Re-using existing token {args.placeholder_token}.")
+    else:
+        print(f"Training new token {args.placeholder_token}.")
 
-    initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
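
Note: the token setup is now more forgiving in two ways: a multi-token initializer
is truncated to its first embedding (with a printed notice) instead of raising, and
an already-registered placeholder token is re-used instead of aborting. A condensed
sketch of the new flow (the tokenizer checkpoint and example strings are
illustrative):

    import torch
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    ids = tokenizer.encode("sculpture", add_special_tokens=False)  # may yield several ids
    initializer_token_ids = torch.tensor(ids[:1])                  # keep only the first

    num_added = tokenizer.add_tokens("<*>")                        # 0 if already present
    placeholder_token_id = tokenizer.convert_tokens_to_ids("<*>")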
@@ -602,7 +608,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_identifier=args.placeholder_token,
+        instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -730,6 +736,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
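
Note: the Checkpointer now receives both values with distinct roles:
instance_identifier drives the prompts used for sample images, while
placeholder_token and placeholder_token_id are still needed to locate and save the
learned embedding. A condensed sketch of the split responsibilities (constructor
reduced to the relevant fields):

    class Checkpointer:
        def __init__(self, instance_identifier, placeholder_token, placeholder_token_id):
            self.instance_identifier = instance_identifier    # fills "{}" in sample prompts
            self.placeholder_token = placeholder_token        # token name to save under
            self.placeholder_token_id = placeholder_token_id  # row in the embedding matrix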
