From 0bc909409648a3cae0061c3de2b39e486473ae39 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 27 Oct 2022 17:57:05 +0200 Subject: Added CLI arg to set dataloader worker num; improved text encoder handling with Dreambooth --- data/csv.py | 10 +++++++--- dreambooth.py | 41 +++++++++++++++++++++++++++++++++-------- textual_inversion.py | 10 ++++++++++ 3 files changed, 50 insertions(+), 11 deletions(-) diff --git a/data/csv.py b/data/csv.py index f9b5e39..6bd7f9b 100644 --- a/data/csv.py +++ b/data/csv.py @@ -38,7 +38,8 @@ class CSVDataModule(pl.LightningDataModule): center_crop: bool = False, valid_set_size: Optional[int] = None, generator: Optional[torch.Generator] = None, - collate_fn=None + collate_fn=None, + num_workers: int = 0 ): super().__init__() @@ -62,6 +63,7 @@ class CSVDataModule(pl.LightningDataModule): self.valid_set_size = valid_set_size self.generator = generator self.collate_fn = collate_fn + self.num_workers = num_workers self.batch_size = batch_size def prepare_subdata(self, template, data, num_class_images=1): @@ -113,9 +115,11 @@ class CSVDataModule(pl.LightningDataModule): size=self.size, interpolation=self.interpolation, center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, - shuffle=True, pin_memory=True, collate_fn=self.collate_fn) + shuffle=True, pin_memory=True, collate_fn=self.collate_fn, + num_workers=self.num_workers) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, - pin_memory=True, collate_fn=self.collate_fn) + pin_memory=True, collate_fn=self.collate_fn, + num_workers=self.num_workers) def train_dataloader(self): return self.train_dataloader_ diff --git a/dreambooth.py b/dreambooth.py index db097e5..e71b7f0 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -72,14 +72,22 @@ def parse_args(): "--placeholder_token", type=str, nargs='*', + default=[], help="A token to use as a placeholder for the concept.", ) parser.add_argument( "--initializer_token", type=str, nargs='*', + default=[], help="A token to use as initializer word." ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + default=True, + help="Whether to train the whole text encoder." + ) parser.add_argument( "--num_class_images", type=int, @@ -118,6 +126,15 @@ def parse_args(): action="store_true", help="Whether to center crop images before resizing to resolution" ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" + " process." + ), + ) parser.add_argument( "--num_train_epochs", type=int, @@ -323,7 +340,7 @@ def parse_args(): args.placeholder_token = [args.placeholder_token] if len(args.placeholder_token) == 0: - args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] + args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))] if len(args.placeholder_token) != len(args.initializer_token): raise ValueError("Number of items in --placeholder_token and --initializer_token must match") @@ -391,6 +408,9 @@ class Checkpointer: @torch.no_grad() def save_embedding(self, step, postfix): + if len(self.placeholder_token) == 0: + return + print("Saving checkpoint for step %d..." % step) checkpoints_path = self.output_dir.joinpath("checkpoints") @@ -406,9 +426,6 @@ class Checkpointer: filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) - del unwrapped - del learned_embeds - @torch.no_grad() def save_model(self): print("Saving model...") @@ -575,7 +592,9 @@ def main(): # Freeze text_encoder and vae freeze_params(vae.parameters()) - if len(args.initializer_token) != 0: + if len(args.placeholder_token) != 0: + print(f"Adding text embeddings: {args.placeholder_token}") + # Convert the initializer_token, placeholder_token to ids initializer_token_ids = torch.stack([ torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) @@ -597,14 +616,19 @@ def main(): for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): token_embeds[token_id] = embeddings + else: + placeholder_token_id = [] + + if args.train_text_encoder: + print(f"Training entire text encoder.") + else: + print(f"Training added text embeddings") freeze_params(itertools.chain( text_encoder.text_model.encoder.parameters(), text_encoder.text_model.final_layer_norm.parameters(), text_encoder.text_model.embeddings.position_embedding.parameters(), )) - else: - placeholder_token_id = [] prompt_processor = PromptProcessor(tokenizer, text_encoder) @@ -700,6 +724,7 @@ def main(): repeats=args.repeats, center_crop=args.center_crop, valid_set_size=args.sample_batch_size*args.sample_batches, + num_workers=args.dataloader_num_workers, collate_fn=collate_fn ) @@ -906,7 +931,7 @@ def main(): accelerator.backward(loss) - if args.initializer_token is not None: + if not args.train_text_encoder: # Keep the token embeddings fixed except the newly added # embeddings for the concept, as we only want to optimize the concept embeddings if accelerator.num_processes > 1: diff --git a/textual_inversion.py b/textual_inversion.py index dd7c3bd..115f3aa 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -116,6 +116,15 @@ def parse_args(): action="store_true", help="Whether to center crop images before resizing to resolution" ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" + " process." + ), + ) parser.add_argument( "--num_train_epochs", type=int, @@ -626,6 +635,7 @@ def main(): repeats=args.repeats, center_crop=args.center_crop, valid_set_size=args.sample_batch_size*args.sample_batches, + num_workers=args.dataloader_num_workers, collate_fn=collate_fn ) -- cgit v1.2.3-54-g00ecf