From 03303d3bddba5a27a123babdf90863e27501e6f8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 13 Dec 2022 23:09:25 +0100
Subject: Unified loading of TI embeddings

---
 dreambooth.py | 26 +++++---------------------
 1 file changed, 5 insertions(+), 21 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 5521b21..3f45754 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -21,6 +21,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -125,7 +126,7 @@ def parse_args():
     parser.add_argument(
         "--embeddings_dir",
         type=str,
-        default="embeddings_ti",
+        default=None,
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
@@ -578,8 +579,6 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    embeddings_dir = Path(args.embeddings_dir)
-
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -629,6 +628,9 @@ def main():
     # Freeze text_encoder and vae
     vae.requires_grad_(False)
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
@@ -645,24 +647,6 @@ def main():
         text_encoder.resize_token_embeddings(len(tokenizer))
 
         token_embeds = text_encoder.get_input_embeddings().weight.data
-
-        print(f"Token ID mappings:")
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = embeddings_dir.joinpath(f"{token}.bin")
-            embedding_source = "init"
-
-            if embedding_file.exists() and embedding_file.is_file():
-                embedding_data = torch.load(embedding_file, map_location="cpu")
-
-                emb = next(iter(embedding_data.values()))
-                if len(emb.shape) == 1:
-                    emb = emb.unsqueeze(0)
-
-                token_embeds[token_id] = emb
-                embedding_source = "file"
-
-            print(f"- {token_id} {token} ({embedding_source})")
-
         original_token_embeds = token_embeds.clone().to(accelerator.device)
 
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
--
cgit v1.2.3-54-g00ecf
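
Note: this patch replaces the inline per-token loading loop with a call to the
shared helper common.load_text_embeddings, whose implementation is not part of
the diff. Below is a minimal sketch of what that helper could look like,
reconstructed from the removed inline logic. The signature, the *.bin glob, and
the tokenizer registration steps are assumptions, not the actual contents of
common.py.

    from pathlib import Path

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer


    def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
        # Assumed layout: one .bin file per embedding, where the file stem is
        # the placeholder token it belongs to (as in the removed loop, which
        # looked up embeddings_dir / f"{token}.bin").
        for embedding_file in sorted(embeddings_dir.glob("*.bin")):
            token = embedding_file.stem

            # Mirrors the removed logic: take the first tensor in the saved
            # dict and normalize a 1-D vector to shape (1, embedding_dim).
            embedding_data = torch.load(embedding_file, map_location="cpu")
            emb = next(iter(embedding_data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)

            # Register the token, grow the embedding matrix, then copy the
            # learned vector into the new row.
            tokenizer.add_tokens(token)
            token_id = tokenizer.convert_tokens_to_ids(token)
            text_encoder.resize_token_embeddings(len(tokenizer))
            text_encoder.get_input_embeddings().weight.data[token_id] = emb

Because the helper discovers embeddings from the directory itself, the caller
no longer needs a placeholder-token-to-file mapping, which is why the hunk at
line 628 can guard the whole thing with a single args.embeddings_dir check.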