path: root/textual_inversion.py
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py | 58
1 file changed, 34 insertions(+), 24 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index c42762f..bcdfd3a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -70,13 +70,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
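
What the switch to nargs='*' means for the command line, as a minimal argparse sketch (the argument names match the patch; example values such as <cat-toy> are made up):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--placeholder_token", type=str, nargs='*')
parser.add_argument("--initializer_token", type=str, nargs='*')

# Several concepts can now be passed in a single run:
args = parser.parse_args(
    "--placeholder_token <cat-toy> <dog-toy> --initializer_token toy toy".split()
)
print(args.placeholder_token)   # ['<cat-toy>', '<dog-toy>']
print(args.initializer_token)   # ['toy', 'toy']
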
@@ -299,12 +299,21 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.placeholder_token is None:
-        raise ValueError("You must specify --placeholder_token")
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
 
-    if args.initializer_token is None:
+    if not args.initializer_token:
         raise ValueError("You must specify --initializer_token")
 
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if not args.placeholder_token:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
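
A minimal standalone sketch (inputs assumed) of the normalization above: when no placeholder tokens are supplied, one is generated per initializer token.

initializer_token = ["toy", "cup"]
placeholder_token = []

if not placeholder_token:
    placeholder_token = [f"<*{i}>" for i in range(len(initializer_token))]

print(placeholder_token)   # ['<*0>', '<*1>']
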
@@ -373,12 +382,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
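
Each placeholder token now gets its own checkpoint file mapping the token to its embedding. A minimal sketch of reading one back (the file name follows the %s_%d_%s.bin pattern above but is hypothetical here):

import torch

learned = torch.load("cat-toy_500_training.bin", map_location="cpu")
for token, embedding in learned.items():
    print(token, tuple(embedding.shape))   # e.g. <cat-toy> (768,)
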
@@ -422,7 +432,7 @@ class Checkpointer:
 
             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
+                prompt = [prompt.format(identifier=self.instance_identifier)
                           for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
                 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
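
This hunk, like the class-image hunk further down, switches from a positional to a named str.format() field, which assumes the prompt templates themselves now contain {identifier}. A minimal sketch of the difference:

template = "a photo of a {identifier}"
print(template.format(identifier="<cat-toy>"))   # a photo of a <cat-toy>

# The old positional form required templates written as "a photo of a {}":
print("a photo of a {}".format("<cat-toy>"))     # a photo of a <cat-toy>
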
@@ -498,16 +508,13 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
-    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
 
-    # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    if num_added_tokens == 0:
-        print(f"Re-using existing token {args.placeholder_token}.")
-    else:
-        print(f"Training new token {args.placeholder_token}.")
+    print(f"Added {num_added_tokens} new tokens.")
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
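
The tokenizer calls above rely on transformers accepting lists: add_tokens() takes a list and returns how many tokens were actually new, and convert_tokens_to_ids() maps a list of tokens to a list of ids. A minimal sketch (checkpoint name and tokens are examples only):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
num_added = tokenizer.add_tokens(["<cat-toy>", "<dog-toy>"])
ids = tokenizer.convert_tokens_to_ids(["<cat-toy>", "<dog-toy>"])
print(num_added, ids)   # e.g. 2 [49408, 49409]
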
@@ -533,11 +540,11 @@ def main():
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
-        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
-            args.placeholder_token]
+        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
     else:
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
     # Freeze vae and unet
     freeze_params(vae.parameters())
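
A self-contained sketch of the copy-by-zip above, using dummy tensors: each newly added row of the embedding matrix starts out as a copy of its initializer token's vector.

import torch

token_embeds = torch.zeros(10, 4)                 # stand-in embedding matrix
initializer_token_embeddings = torch.ones(2, 4)   # one row per initializer token
placeholder_token_id = [8, 9]                     # ids of the newly added tokens

for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
    token_embeds[token_id] = embeddings

print(token_embeds[8])   # tensor([1., 1., 1., 1.])
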
@@ -648,7 +655,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
+                prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]
 
                 images = pipeline(
@@ -716,7 +723,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("textual_inversion", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
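
The token lists are joined into strings before init_trackers() presumably because tracker backends such as TensorBoard's hparams only accept scalar config values. A minimal sketch of the flattening with assumed values:

config = {"placeholder_token": ["<cat-toy>", "<dog-toy>"], "initializer_token": ["toy", "toy"]}
config["placeholder_token"] = " ".join(config["placeholder_token"])
config["initializer_token"] = " ".join(config["initializer_token"])
print(config)   # {'placeholder_token': '<cat-toy> <dog-toy>', 'initializer_token': 'toy toy'}
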