From 49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 6 Oct 2022 17:15:22 +0200 Subject: Update --- infer.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'infer.py') diff --git a/infer.py b/infer.py index 3487e5a..34e570a 100644 --- a/infer.py +++ b/infer.py @@ -171,6 +171,18 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir): embeddings_dir = Path(embeddings_dir) embeddings_dir.mkdir(parents=True, exist_ok=True) + for file in embeddings_dir.iterdir(): + placeholder_token = file.stem + + num_added_tokens = tokenizer.add_tokens(placeholder_token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data for file in embeddings_dir.iterdir(): @@ -187,6 +199,8 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir): token_embeds[placeholder_token_id] = emb + print(f"Loaded embedding: {placeholder_token}") + def create_pipeline(model, scheduler, embeddings_dir, dtype): print("Loading Stable Diffusion pipeline...") -- cgit v1.2.3-54-g00ecf