From 4ef37d87d5a04bc6bb7dacee0660bba3057cc02f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 5 Jan 2023 10:26:17 +0100 Subject: Fix --- train_dreambooth.py | 4 ++-- train_ti.py | 19 +++++++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index c658ad6..1fd86b1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -612,7 +612,7 @@ def main(): raise ValueError("--embeddings_dir must point to an existing directory") added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}") + print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") if len(args.placeholder_token) != 0: # Convert the initializer_token, placeholder_token to ids @@ -629,7 +629,7 @@ def main(): for (new_id, init_ids) in zip(new_ids, initializer_token_ids) ] - print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}") + print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") else: placeholder_token_id = [] diff --git a/train_ti.py b/train_ti.py index 5df6850..164cf67 100644 --- a/train_ti.py +++ b/train_ti.py @@ -450,7 +450,8 @@ class Checkpointer(CheckpointerBase): tokenizer, text_encoder, scheduler, - new_tokens, + placeholder_token, + new_ids, output_dir: Path, sample_image_size, sample_batches, @@ -473,7 +474,8 @@ class Checkpointer(CheckpointerBase): self.tokenizer = tokenizer self.text_encoder = text_encoder self.scheduler = scheduler - self.new_tokens = new_tokens + self.placeholder_token = placeholder_token + self.new_ids = new_ids @torch.no_grad() def checkpoint(self, step, postfix): @@ -484,10 +486,10 @@ class Checkpointer(CheckpointerBase): text_encoder = self.accelerator.unwrap_model(self.text_encoder) - for new_token in self.new_tokens: + for (token, ids) in zip(self.placeholder_token, self.new_ids): text_encoder.text_model.embeddings.save_embed( - new_token.ids, - checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin") + ids, + checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") ) del text_encoder @@ -572,7 +574,7 @@ def main(): raise ValueError("--embeddings_dir must point to an existing directory") added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}") + print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") # Convert the initializer_token, placeholder_token to ids initializer_token_ids = [ @@ -588,7 +590,7 @@ def main(): for (new_id, init_ids) in zip(new_ids, initializer_token_ids) ] - print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}") + print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") vae.requires_grad_(False) unet.requires_grad_(False) @@ -882,7 +884,8 @@ def main(): tokenizer=tokenizer, text_encoder=text_encoder, scheduler=checkpoint_scheduler, - new_tokens=new_tokens, + placeholder_token=args.placeholder_token, + new_ids=new_ids, output_dir=basepath, sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, -- cgit v1.2.3-70-g09d2