From 0bc909409648a3cae0061c3de2b39e486473ae39 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 27 Oct 2022 17:57:05 +0200
Subject: Added CLI arg to set dataloader worker num; improved text encoder
 handling with Dreambooth

---
 dreambooth.py | 41 +++++++++++++++++++++++++++++++++--------
 1 file changed, 33 insertions(+), 8 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index db097e5..e71b7f0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -72,14 +72,22 @@ def parse_args():
         "--placeholder_token",
         type=str,
         nargs='*',
+        default=[],
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
         nargs='*',
+        default=[],
         help="A token to use as initializer word."
     )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        default=True,
+        help="Whether to train the whole text encoder."
+    )
     parser.add_argument(
         "--num_class_images",
         type=int,
@@ -118,6 +126,15 @@ def parse_args():
         action="store_true",
         help="Whether to center crop images before resizing to resolution"
     )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
     parser.add_argument(
         "--num_train_epochs",
         type=int,
@@ -323,7 +340,7 @@ def parse_args():
         args.placeholder_token = [args.placeholder_token]
 
     if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
 
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
@@ -391,6 +408,9 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_embedding(self, step, postfix):
+        if len(self.placeholder_token) == 0:
+            return
+
         print("Saving checkpoint for step %d..." % step)
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -406,9 +426,6 @@ class Checkpointer:
             filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
             torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
-        del unwrapped
-        del learned_embeds
-
     @torch.no_grad()
     def save_model(self):
         print("Saving model...")
@@ -575,7 +592,9 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
-    if len(args.initializer_token) != 0:
+    if len(args.placeholder_token) != 0:
+        print(f"Adding text embeddings: {args.placeholder_token}")
+
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
             torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -597,14 +616,19 @@ def main():
 
         for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
             token_embeds[token_id] = embeddings
+    else:
+        placeholder_token_id = []
+
+    if args.train_text_encoder:
+        print(f"Training entire text encoder.")
+    else:
+        print(f"Training added text embeddings")
 
         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
             text_encoder.text_model.final_layer_norm.parameters(),
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
-    else:
-        placeholder_token_id = []
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -700,6 +724,7 @@ def main():
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
+        num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
    )
 
@@ -906,7 +931,7 @@ def main():
 
             accelerator.backward(loss)
 
-            if args.initializer_token is not None:
+            if not args.train_text_encoder:
                 # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-- 
cgit v1.2.3-54-g00ecf
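
A note on the last hunk: the guard switches from "args.initializer_token is not None" to "not args.train_text_encoder", so the token-embedding gradient masking now runs whenever the full text encoder is not being trained. The body of that branch is outside this diff, so the following is only a minimal sketch of what such a masking step typically looks like; the helper name and the exact gradient access are assumptions, not this repository's code.

import torch

def zero_non_placeholder_grads(text_encoder, placeholder_token_id):
    # Sketch only (assumed implementation): zero the gradient of every row of the
    # token-embedding matrix except the newly added placeholder tokens, so the
    # optimizer only updates the concept embeddings.
    # Under multi-process training the embeddings usually live on
    # text_encoder.module instead of text_encoder (cf. the num_processes check).
    grads = text_encoder.get_input_embeddings().weight.grad
    if grads is None:
        return
    keep_fixed = torch.ones(grads.shape[0], dtype=torch.bool, device=grads.device)
    keep_fixed[placeholder_token_id] = False  # rows that are allowed to change
    grads[keep_fixed, :] = 0.0

In a training loop like the one shown, such a step would sit between accelerator.backward(loss) and the optimizer step.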