Diffstat (limited to 'textual_inversion.py')
-rw-r--r--   textual_inversion.py   23
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index fd4a313..6d8fd77 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -22,6 +22,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -105,6 +106,12 @@ def parse_args():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -551,6 +558,9 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
         torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -562,10 +572,6 @@
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
-    print(f"Token ID mappings:")
-    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-        print(f"- {token_id} {token}")
-
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
@@ -576,14 +582,7 @@
         resumepath = Path(args.resume_from).joinpath("checkpoints")
 
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin")
-            embedding_data = torch.load(embedding_file, map_location="cpu")
-
-            emb = next(iter(embedding_data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[token_id] = emb
+            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
 
     original_token_embeds = token_embeds.clone().to(accelerator.device)
 
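The change factors the inline checkpoint-loading code out of main() into two helpers imported from common. Their actual implementation is not part of this diff; the sketch below is a plausible reconstruction, assuming load_text_embedding mirrors the inline logic the diff removes, and assuming load_text_embeddings scans a directory of <token>.bin files whose stems are the placeholder tokens (that naming convention is an assumption, not confirmed by this diff).

# Hypothetical reconstruction of the helpers imported from common.
# load_text_embedding mirrors the inline code this diff deletes from main();
# load_text_embeddings and the <token>.bin naming convention are assumptions.
from pathlib import Path

import torch
from transformers import CLIPTextModel, CLIPTokenizer


def load_text_embedding(token_embeds, token_id, file):
    # Load a single embedding checkpoint (a dict holding one tensor) on the CPU.
    embedding_data = torch.load(file, map_location="cpu")
    emb = next(iter(embedding_data.values()))

    # Stored embeddings may be 1-D; give them a leading token dimension.
    if len(emb.shape) == 1:
        emb = emb.unsqueeze(0)

    # Overwrite the row of the embedding matrix that belongs to this token.
    token_embeds[token_id] = emb


def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
    # Assumed convention: every *.bin file in the directory holds one embedding
    # and its stem is the placeholder token it should be registered under.
    files = sorted(embeddings_dir.glob("*.bin"))
    tokens = [file.stem for file in files]

    # Register the placeholder tokens and grow the embedding matrix to match.
    tokenizer.add_tokens(tokens)
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_embeds = text_encoder.get_input_embeddings().weight.data
    for file, token in zip(files, tokens):
        token_id = tokenizer.convert_tokens_to_ids(token)
        load_text_embedding(token_embeds, token_id, file)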