path: root/textual_inversion.py
author     Volpeon <git@volpeon.ink>  2022-12-14 09:43:45 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-14 09:43:45 +0100
commit     279174a7a31f0fc6ed209e5b46901e50fe722c87 (patch)
tree       ec12ec9a66c5e6532aa0be08608c638283e090fb /textual_inversion.py
parent     Unified loading of TI embeddings (diff)
More generic dataset filter
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  13
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 6d8fd77..a849d2a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,7 +25,7 @@ from slugify import slugify
 from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
-from data.csv import CSVDataModule
+from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
@@ -559,7 +559,11 @@ def main():
         text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
-        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
@@ -637,6 +641,9 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    def keyword_filter(item: CSVDataItem):
+        return any(keyword in item.prompt for keyword in args.placeholder_token)
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -677,7 +684,7 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
-        keyword_filter=args.placeholder_token,
+        filter=keyword_filter,
         collate_fn=collate_fn
     )
 
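The datamodule side of this change is not part of the diff. The sketch below illustrates how a generic filter callable could be consumed when the dataset items are built; only the filter= keyword argument and the item.prompt attribute appear in this commit, while the prepare_items method, the extra CSVDataItem fields, and the demo data are assumptions for illustration, not the repository's actual data/csv.py.

# Hypothetical sketch: how a generic `filter` predicate might be applied inside
# the datamodule. Names other than `filter` and `item.prompt` are assumptions.
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class CSVDataItem:
    prompt: str
    nprompt: str = ""


class CSVDataModule:
    def __init__(self, filter: Optional[Callable[[CSVDataItem], bool]] = None, **kwargs):
        # Accept any predicate over a CSVDataItem instead of a fixed keyword list.
        self.filter = filter

    def prepare_items(self, items: List[CSVDataItem]) -> List[CSVDataItem]:
        # With no filter, keep every item; otherwise keep only items the predicate accepts.
        if self.filter is None:
            return items
        return [item for item in items if self.filter(item)]


# The keyword filter from this commit is just one possible predicate.
placeholder_token = ["<token1>"]

def keyword_filter(item: CSVDataItem) -> bool:
    return any(keyword in item.prompt for keyword in placeholder_token)

datamodule = CSVDataModule(filter=keyword_filter)
items = [CSVDataItem(prompt="a photo of <token1>"), CSVDataItem(prompt="a photo of a cat")]
print(datamodule.prepare_items(items))  # keeps only the <token1> item

Passing a callable rather than a token list keeps the dataset code agnostic of how items are selected: the same hook supports the placeholder-token keyword filter used in this script as well as any other per-item predicate.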