path: root/dreambooth.py
author     Volpeon <git@volpeon.ink>    2022-10-24 23:46:18 +0200
committer  Volpeon <git@volpeon.ink>    2022-10-24 23:46:18 +0200
commit     baba91864a45939cef4f77f6ca96ade7ae5ef274 (patch)
tree       c40fc949a94d5a2bee81b2b505b814e7c7f82cc1 /dreambooth.py
parent     Update (diff)
Advanced datasets
Diffstat (limited to 'dreambooth.py')
-rw-r--r--   dreambooth.py   68
1 file changed, 42 insertions, 26 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 5c26f12..2c24908 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -71,13 +71,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
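For reference, a minimal standalone sketch of what `nargs='*'` gives the rest of the script: each flag now parses into a list of strings instead of a single value. The token values below are made up for illustration; only the option names and the `nargs` setting come from the diff.

import argparse

# Minimal demo of the nargs='*' behaviour the diff switches to
# (a toy parser, not the full dreambooth.py argument list).
parser = argparse.ArgumentParser()
parser.add_argument("--placeholder_token", type=str, nargs='*')
parser.add_argument("--initializer_token", type=str, nargs='*')

args = parser.parse_args(
    ["--placeholder_token", "<luna>", "<sol>", "--initializer_token", "cat", "sun"]
)
print(args.placeholder_token)   # ['<luna>', '<sol>']
print(args.initializer_token)   # ['cat', 'sun']
# If a flag is omitted the attribute is None; if it is given without values, it is [].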
@@ -316,6 +316,18 @@ def parse_args():
     if args.instance_identifier is None:
         raise ValueError("You must specify --instance_identifier")
 
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
+
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
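Pulled out of context, the added checks behave roughly like the helper below: single strings are coerced to lists, placeholder names are auto-generated when none are given, and mismatched lengths are rejected. The function name `normalize_token_args` is invented here for illustration; the logic mirrors the added lines.

def normalize_token_args(placeholder_token, initializer_token):
    """Coerce single strings to lists, autogenerate placeholder names,
    and require one placeholder per initializer word."""
    if isinstance(initializer_token, str):
        initializer_token = [initializer_token]
    if isinstance(placeholder_token, str):
        placeholder_token = [placeholder_token]
    if len(placeholder_token) == 0:
        placeholder_token = [f"<*{i}>" for i in range(len(initializer_token))]
    if len(placeholder_token) != len(initializer_token):
        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
    return placeholder_token, initializer_token


print(normalize_token_args("<*>", "cat"))        # (['<*>'], ['cat'])
print(normalize_token_args([], ["cat", "sun"]))  # (['<*0>', '<*1>'], ['cat', 'sun'])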
@@ -379,9 +391,6 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_embedding(self, step, postfix):
-        if self.placeholder_token_id is None:
-            return
-
         print("Saving checkpoint for step %d..." % step)
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -389,12 +398,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
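The per-token loop writes one small .bin file per placeholder, keyed by the token string. A rough standalone equivalent is sketched below; the token names, ids, and the tiny embedding table are toy stand-ins, not the real text encoder, and slugify is assumed to be python-slugify as used elsewhere in the script.

from pathlib import Path

import torch
from slugify import slugify  # python-slugify

# Toy stand-ins: two placeholder tokens and a small fake embedding table.
placeholder_token = ["<luna>", "<sol>"]
placeholder_token_id = [10, 11]
token_embeds = torch.nn.Embedding(12, 16).weight  # pretend text-encoder input embeddings

checkpoints_path = Path("checkpoints")
checkpoints_path.mkdir(parents=True, exist_ok=True)

step, postfix = 500, "training"
for token, token_id in zip(placeholder_token, placeholder_token_id):
    learned_embeds = token_embeds[token_id]
    # One small file per token, keyed by the token string.
    filename = "%s_%d_%s.bin" % (slugify(token), step, postfix)
    torch.save({token: learned_embeds.detach().cpu()}, checkpoints_path.joinpath(filename))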
@@ -467,7 +477,7 @@ class Checkpointer:
             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
                 prompt = [
-                    prompt.format(self.instance_identifier)
+                    prompt.format(identifier=self.instance_identifier)
                     for batch in batches
                     for prompt in batch["prompts"]
                 ][:self.sample_batch_size]
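The format() call sites switch from a positional `{}` to a named `{identifier}` field, so prompt templates must spell out the field name. A quick illustration (the template string here is made up):

# Named format field, as the updated call sites expect.
template = "a photo of {identifier} in the garden"
print(template.format(identifier="<luna>"))
# -> "a photo of <luna> in the garden"

# A positional call such as template.format("<luna>") would raise KeyError('identifier'),
# which is why every format() call in this diff passes identifier=... explicitly.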
@@ -516,8 +526,8 @@ def main():
 
     instance_identifier = args.instance_identifier
 
-    if args.placeholder_token is not None:
-        instance_identifier = instance_identifier.format(args.placeholder_token)
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
@@ -565,18 +575,16 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
-    if args.initializer_token is not None:
+    if len(args.initializer_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-        print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.")
-        initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+        initializer_token_ids = torch.stack([
+            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+            for token in args.initializer_token
+        ])
 
-        # Add the placeholder token in tokenizer
         num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        if num_added_tokens == 0:
-            print(f"Re-using existing token {args.placeholder_token}.")
-        else:
-            print(f"Training new token {args.placeholder_token}.")
+        print(f"Added {num_added_tokens} new tokens.")
+
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
         # Resize the token embeddings as we are adding new special tokens to the tokenizer
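The new lookup keeps only the first sub-token of each initializer word and stacks the ids into an [n, 1] tensor before registering the placeholder tokens. A hedged sketch with Hugging Face's CLIP tokenizer follows; the model name is the usual SD 1.x text encoder and the token lists are examples, both assumptions rather than values from the diff.

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

initializer_token = ["cat", "sun"]
placeholder_token = ["<luna>", "<sol>"]

# First sub-token id of every initializer word, stacked to shape [n, 1].
initializer_token_ids = torch.stack([
    torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
    for token in initializer_token
])

num_added_tokens = tokenizer.add_tokens(placeholder_token)  # 0 if the tokens already exist
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

print(initializer_token_ids.shape)        # torch.Size([2, 1])
print(num_added_tokens, placeholder_token_id)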
@@ -586,7 +594,9 @@ def main():
         token_embeds = text_encoder.get_input_embeddings().weight.data
         original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
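With placeholder_token_id now a list and one initializer embedding per token, the single tensor assignment becomes an explicit row-by-row copy. A toy sketch of that step, with made-up sizes and ids and the id tensor flattened for simplicity (in the script the rows come from text_encoder.get_input_embeddings() and the stacked ids built above):

import torch

emb_dim = 8
embedding = torch.nn.Embedding(12, emb_dim)   # stand-in for the text encoder's input embeddings
token_embeds = embedding.weight.data          # raw embedding matrix, shape [12, emb_dim]

placeholder_token_id = [10, 11]               # ids of the newly added placeholder tokens
initializer_token_ids = [3, 7]                # first sub-token id of each initializer word

# Look up the initializer rows and copy each one into its placeholder slot,
# mirroring the zip loop added in the diff.
initializer_token_embeddings = token_embeds[initializer_token_ids]   # shape [2, emb_dim]
for token_id, embeds in zip(placeholder_token_id, initializer_token_embeddings):
    token_embeds[token_id] = embeds

assert torch.allclose(token_embeds[10], token_embeds[3])
assert torch.allclose(token_embeds[11], token_embeds[7])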
@@ -594,7 +604,7 @@ def main():
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
     else:
-        placeholder_token_id = None
+        placeholder_token_id = []
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -721,7 +731,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
-                prompt = [item.prompt.format(args.class_identifier) for item in batch]
+                prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
                 nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
@@ -787,7 +797,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
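Tracker backends (for example TensorBoard's hparams logging) generally expect scalar config values, so the list-valued token arguments are joined into single strings before init_trackers is called. Roughly, the transformation looks like the sketch below; the namespace contents are invented for illustration.

import argparse

# Toy namespace standing in for the parsed training arguments.
args = argparse.Namespace(
    placeholder_token=["<luna>", "<sol>"],
    initializer_token=["cat", "sun"],
    train_batch_size=1,
)

config = vars(args).copy()
config["initializer_token"] = " ".join(config["initializer_token"])
config["placeholder_token"] = " ".join(config["placeholder_token"])

print(config)
# {'placeholder_token': '<luna> <sol>', 'initializer_token': 'cat sun', 'train_batch_size': 1}
# This flattened dict is what gets passed as config= to the tracker initialization.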
@@ -932,6 +945,9 @@ def main():
                 global_step += 1
 
                 if global_step % args.sample_frequency == 0:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     checkpointer.save_embedding(global_step, "training")
                     sample_checkpoint = True
 