diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 13 |
1 file changed, 10 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 6d8fd77..a849d2a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,7 +25,7 @@ from slugify import slugify | |||
25 | from common import load_text_embeddings, load_text_embedding | 25 | from common import load_text_embeddings, load_text_embedding |
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from pipelines.util import set_use_memory_efficient_attention_xformers | 27 | from pipelines.util import set_use_memory_efficient_attention_xformers |
28 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule, CSVDataItem |
29 | from training.optimization import get_one_cycle_schedule | 29 | from training.optimization import get_one_cycle_schedule |
30 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
31 | 31 | ||
@@ -559,7 +559,11 @@ def main(): | |||
559 | text_encoder.gradient_checkpointing_enable() | 559 | text_encoder.gradient_checkpointing_enable() |
560 | 560 | ||
561 | if args.embeddings_dir is not None: | 561 | if args.embeddings_dir is not None: |
562 | load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) | 562 | embeddings_dir = Path(args.embeddings_dir) |
563 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
564 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
565 | added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) | ||
566 | print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}") | ||
563 | 567 | ||
564 | # Convert the initializer_token, placeholder_token to ids | 568 | # Convert the initializer_token, placeholder_token to ids |
565 | initializer_token_ids = torch.stack([ | 569 | initializer_token_ids = torch.stack([ |
@@ -637,6 +641,9 @@ def main(): | |||
637 | elif args.mixed_precision == "bf16": | 641 | elif args.mixed_precision == "bf16": |
638 | weight_dtype = torch.bfloat16 | 642 | weight_dtype = torch.bfloat16 |
639 | 643 | ||
644 | def keyword_filter(item: CSVDataItem): | ||
645 | return any(keyword in item.prompt for keyword in args.placeholder_token) | ||
646 | |||
640 | def collate_fn(examples): | 647 | def collate_fn(examples): |
641 | prompts = [example["prompts"] for example in examples] | 648 | prompts = [example["prompts"] for example in examples] |
642 | nprompts = [example["nprompts"] for example in examples] | 649 | nprompts = [example["nprompts"] for example in examples] |
@@ -677,7 +684,7 @@ def main(): | |||
677 | template_key=args.train_data_template, | 684 | template_key=args.train_data_template, |
678 | valid_set_size=args.valid_set_size, | 685 | valid_set_size=args.valid_set_size, |
679 | num_workers=args.dataloader_num_workers, | 686 | num_workers=args.dataloader_num_workers, |
680 | keyword_filter=args.placeholder_token, | 687 | filter=keyword_filter, |
681 | collate_fn=collate_fn | 688 | collate_fn=collate_fn |
682 | ) | 689 | ) |
683 | 690 | ||