path: root/textual_inversion.py
author     Volpeon <git@volpeon.ink>    2022-12-13 23:09:25 +0100
committer  Volpeon <git@volpeon.ink>    2022-12-13 23:09:25 +0100
commit     03303d3bddba5a27a123babdf90863e27501e6f8 (patch)
tree       8266c50f8e474d92ad4b42773cb8eb7730cd24c1 /textual_inversion.py
parent     Optimized Textual Inversion training by filtering dataset by existence of add... (diff)
Unified loading of TI embeddings
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  23
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index fd4a313..6d8fd77 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -22,6 +22,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -105,6 +106,12 @@ def parse_args():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -551,6 +558,9 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
         torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
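This hunk only adds the call site; load_text_embeddings itself lives in common.py, which this diff does not touch. Below is a minimal sketch of what such a helper could look like, assuming each embedding in --embeddings_dir is a <token>.bin file named after its placeholder token and that per-file loading is delegated to load_text_embedding (sketched after the last hunk). The file layout and internals here are assumptions for illustration, not the repository's actual implementation.

# Hypothetical sketch, not the code from common.py: register every embedding
# found in embeddings_dir as a new token and copy its vector into the text
# encoder's embedding matrix.
from pathlib import Path


def load_text_embeddings(tokenizer, text_encoder, embeddings_dir: Path):
    # Assumption: each file is "<placeholder_token>.bin" as written by a
    # Textual Inversion checkpoint.
    files = sorted(embeddings_dir.glob("*.bin"))
    tokens = [file.stem for file in files]
    tokenizer.add_tokens(tokens)

    # Grow the embedding matrix so the new token ids have rows to write into.
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data

    for token, file in zip(tokens, files):
        token_id = tokenizer.convert_tokens_to_ids(token)
        load_text_embedding(token_embeds, token_id, file)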
@@ -562,10 +572,6 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
-    print(f"Token ID mappings:")
-    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-        print(f"- {token_id} {token}")
-
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
@@ -576,14 +582,7 @@ def main():
         resumepath = Path(args.resume_from).joinpath("checkpoints")
 
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin")
-            embedding_data = torch.load(embedding_file, map_location="cpu")
-
-            emb = next(iter(embedding_data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[token_id] = emb
+            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
 
     original_token_embeds = token_embeds.clone().to(accelerator.device)
 
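The eight lines removed in this last hunk are presumably what was factored out into load_text_embedding; the sketch below is reconstructed from that removed inline code (the real definition is in common.py, outside this diff).

import torch


def load_text_embedding(token_embeds, token_id, file):
    # A saved embedding .bin maps the placeholder token to its tensor;
    # take the first (and only) value.
    embedding_data = torch.load(file, map_location="cpu")
    emb = next(iter(embedding_data.values()))

    # Ensure a 2-D (1, dim) shape before writing it into the embedding matrix.
    if len(emb.shape) == 1:
        emb = emb.unsqueeze(0)

    token_embeds[token_id] = emb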