| author | Volpeon <git@volpeon.ink> | 2022-12-14 09:43:45 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-14 09:43:45 +0100 |
| commit | 279174a7a31f0fc6ed209e5b46901e50fe722c87 (patch) | |
| tree | ec12ec9a66c5e6532aa0be08608c638283e090fb /textual_inversion.py | |
| parent | Unified loading of TI embeddings (diff) | |
More generic dataset filter
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 13 |
1 file changed, 10 insertions(+), 3 deletions(-)
```diff
diff --git a/textual_inversion.py b/textual_inversion.py
index 6d8fd77..a849d2a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,7 +25,7 @@ from slugify import slugify
 from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
-from data.csv import CSVDataModule
+from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
@@ -559,7 +559,11 @@ def main():
         text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
-        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
@@ -637,6 +641,9 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    def keyword_filter(item: CSVDataItem):
+        return any(keyword in item.prompt for keyword in args.placeholder_token)
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -677,7 +684,7 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
-        keyword_filter=args.placeholder_token,
+        filter=keyword_filter,
         collate_fn=collate_fn
     )
 
```
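The change replaces the keyword-list argument (`keyword_filter=args.placeholder_token`) with a generic callable (`filter=keyword_filter`), so `CSVDataModule` no longer needs to know how items are matched; it only needs to apply an arbitrary predicate to each `CSVDataItem`. Below is a minimal sketch of that pattern, not code from this repository: the fields of `CSVDataItem` and the `apply_filter` helper are assumptions for illustration, while the `prompt` attribute and the predicate signature come from the diff above.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class CSVDataItem:
    # Hypothetical minimal shape; the real class lives in data/csv.py and
    # keyword_filter above only relies on the `prompt` attribute.
    instance_image_path: str
    prompt: str
    nprompt: str


def apply_filter(
    items: list[CSVDataItem],
    filter: Optional[Callable[[CSVDataItem], bool]] = None,
) -> list[CSVDataItem]:
    # A generic predicate replaces the old keyword list: any callable that
    # maps an item to True/False decides what stays in the dataset.
    # (The parameter is named `filter` to mirror the new keyword argument.)
    if filter is None:
        return items
    return [item for item in items if filter(item)]


if __name__ == "__main__":
    placeholder_token = ["<token>"]

    def keyword_filter(item: CSVDataItem) -> bool:
        # Same predicate shape the diff defines inside main().
        return any(keyword in item.prompt for keyword in placeholder_token)

    items = [
        CSVDataItem("a.png", "a photo of <token>", ""),
        CSVDataItem("b.png", "a photo of a cat", ""),
    ]
    print([item.prompt for item in apply_filter(items, keyword_filter)])
    # -> ['a photo of <token>']
```

Keeping the keyword matching on the caller's side means other training scripts can pass entirely different predicates (for example, filtering by image path or negative prompt) without touching the data module.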
