path: root/dreambooth.py
author     Volpeon <git@volpeon.ink>  2022-10-27 17:57:05 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-27 17:57:05 +0200
commit     0bc909409648a3cae0061c3de2b39e486473ae39 (patch)
tree       5fdbcd7c56919293963c3c8b53bdb2099834079d /dreambooth.py
parent     Euler_a: Re-introduce generator arg for reproducible output (diff)
Added CLI arg to set dataloader worker num; improved text encoder handling with Dreambooth
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py | 41
1 file changed, 33 insertions(+), 8 deletions(-)
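Note: the data module that consumes the new num_workers value is not part of this diff. As a rough sketch only (the helper name make_train_dataloader and its signature are assumptions, not code from this repository), the value supplied via --dataloader_num_workers would typically be forwarded to torch.utils.data.DataLoader like this:

    # Hypothetical sketch; illustrates how a --dataloader_num_workers value is
    # usually passed through to a PyTorch DataLoader, not code from this commit.
    from torch.utils.data import DataLoader

    def make_train_dataloader(dataset, batch_size, collate_fn, num_workers=0):
        # num_workers=0 loads batches in the main process; >0 spawns worker subprocesses.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=num_workers,
            pin_memory=True,
        )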
diff --git a/dreambooth.py b/dreambooth.py
index db097e5..e71b7f0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -72,15 +72,23 @@ def parse_args():
72 "--placeholder_token", 72 "--placeholder_token",
73 type=str, 73 type=str,
74 nargs='*', 74 nargs='*',
75 default=[],
75 help="A token to use as a placeholder for the concept.", 76 help="A token to use as a placeholder for the concept.",
76 ) 77 )
77 parser.add_argument( 78 parser.add_argument(
78 "--initializer_token", 79 "--initializer_token",
79 type=str, 80 type=str,
80 nargs='*', 81 nargs='*',
82 default=[],
81 help="A token to use as initializer word." 83 help="A token to use as initializer word."
82 ) 84 )
83 parser.add_argument( 85 parser.add_argument(
86 "--train_text_encoder",
87 action="store_true",
88 default=True,
89 help="Whether to train the whole text encoder."
90 )
91 parser.add_argument(
84 "--num_class_images", 92 "--num_class_images",
85 type=int, 93 type=int,
86 default=400, 94 default=400,
@@ -119,6 +127,15 @@ def parse_args():
119 help="Whether to center crop images before resizing to resolution" 127 help="Whether to center crop images before resizing to resolution"
120 ) 128 )
121 parser.add_argument( 129 parser.add_argument(
130 "--dataloader_num_workers",
131 type=int,
132 default=0,
133 help=(
134 "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
135 " process."
136 ),
137 )
138 parser.add_argument(
122 "--num_train_epochs", 139 "--num_train_epochs",
123 type=int, 140 type=int,
124 default=100 141 default=100
@@ -323,7 +340,7 @@ def parse_args():
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
 
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
@@ -391,6 +408,9 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_embedding(self, step, postfix):
+        if len(self.placeholder_token) == 0:
+            return
+
         print("Saving checkpoint for step %d..." % step)
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -406,9 +426,6 @@ class Checkpointer:
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
             torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
-        del unwrapped
-        del learned_embeds
-
     @torch.no_grad()
     def save_model(self):
         print("Saving model...")
@@ -575,7 +592,9 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
-    if len(args.initializer_token) != 0:
+    if len(args.placeholder_token) != 0:
+        print(f"Adding text embeddings: {args.placeholder_token}")
+
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
             torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -597,14 +616,19 @@ def main():
 
         for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
             token_embeds[token_id] = embeddings
+    else:
+        placeholder_token_id = []
+
+    if args.train_text_encoder:
+        print(f"Training entire text encoder.")
+    else:
+        print(f"Training added text embeddings")
 
         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
             text_encoder.text_model.final_layer_norm.parameters(),
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
-    else:
-        placeholder_token_id = []
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -700,6 +724,7 @@ def main():
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
+        num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
     )
 
@@ -906,7 +931,7 @@ def main():
 
                 accelerator.backward(loss)
 
-                if args.initializer_token is not None:
+                if not args.train_text_encoder:
                     # Keep the token embeddings fixed except the newly added
                     # embeddings for the concept, as we only want to optimize the concept embeddings
                     if accelerator.num_processes > 1:
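The body of the `if not args.train_text_encoder:` branch is cut off by the diff context above. For orientation only, a minimal sketch of the usual textual-inversion approach that the surrounding comment describes: zero the embedding gradients for every token id except the newly added placeholders before the optimizer step. The helper name zero_frozen_embedding_grads is hypothetical, and this is not the code from this commit.

    # Hedged sketch, not the code elided by the diff above.
    import torch

    def zero_frozen_embedding_grads(text_encoder, placeholder_token_id):
        grads = text_encoder.get_input_embeddings().weight.grad
        if grads is None:
            return
        # Mask out gradients for all token ids that are not newly added placeholders,
        # so only the concept embeddings receive updates.
        index_no_updates = torch.ones(grads.shape[0], dtype=torch.bool, device=grads.device)
        index_no_updates[placeholder_token_id] = False
        grads.data[index_no_updates, :] = 0.0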