Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py | 58 ++++++++++++++++++++++++++++++++++++------------------------
1 file changed, 34 insertions(+), 24 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index c42762f..bcdfd3a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -70,13 +70,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
@@ -299,12 +299,21 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.placeholder_token is None:
-        raise ValueError("You must specify --placeholder_token")
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
 
-    if args.initializer_token is None:
+    if len(args.initializer_token) == 0:
         raise ValueError("You must specify --initializer_token")
 
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of values")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
@@ -373,12 +382,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -422,7 +432,7 @@ class Checkpointer:
 
             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
+                prompt = [prompt.format(identifier=self.instance_identifier)
                           for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
                 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
@@ -498,16 +508,13 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
-    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
 
-    # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    if num_added_tokens == 0:
-        print(f"Re-using existing token {args.placeholder_token}.")
-    else:
-        print(f"Training new token {args.placeholder_token}.")
+    print(f"Added {num_added_tokens} new tokens.")
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
@@ -533,11 +540,11 @@ def main():
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
-        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
-            args.placeholder_token]
+        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
     else:
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
     # Freeze vae and unet
     freeze_params(vae.parameters())
@@ -648,7 +655,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
+                prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]
 
                 images = pipeline(
@@ -716,7 +723,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("textual_inversion", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
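
With this change, --placeholder_token and --initializer_token each accept multiple space-separated values (e.g. --placeholder_token "<tok-a>" "<tok-b>" --initializer_token word1 word2, where the values are purely illustrative), and the Checkpointer writes one .bin file per learned token, keyed by the placeholder string. Below is a minimal, hypothetical sketch (not part of this commit) of how such a checkpoint could be merged back into a tokenizer and text encoder; the function name and arguments are illustrative assumptions, not code from this repository.

    # Hypothetical sketch, not part of this commit: merge a learned embedding
    # checkpoint written by Checkpointer into a tokenizer/text encoder pair.
    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    def load_learned_embeds(checkpoint_path: str, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
        # Each checkpoint holds {placeholder_token: embedding tensor}, as saved above.
        learned_embeds_dict = torch.load(checkpoint_path, map_location="cpu")
        for token, embedding in learned_embeds_dict.items():
            # Register the placeholder token if it is not already in the vocabulary,
            # then write the learned vector into the matching embedding row.
            if tokenizer.add_tokens(token) > 0:
                text_encoder.resize_token_embeddings(len(tokenizer))
            token_id = tokenizer.convert_tokens_to_ids(token)
            text_encoder.get_input_embeddings().weight.data[token_id] = embedding
        return tokenizer, text_encoder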