From 279174a7a31f0fc6ed209e5b46901e50fe722c87 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 14 Dec 2022 09:43:45 +0100 Subject: More generic dataset filter --- textual_inversion.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 6d8fd77..a849d2a 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -25,7 +25,7 @@ from slugify import slugify from common import load_text_embeddings, load_text_embedding from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from pipelines.util import set_use_memory_efficient_attention_xformers -from data.csv import CSVDataModule +from data.csv import CSVDataModule, CSVDataItem from training.optimization import get_one_cycle_schedule from models.clip.prompt import PromptProcessor @@ -559,7 +559,11 @@ def main(): text_encoder.gradient_checkpointing_enable() if args.embeddings_dir is not None: - load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) + embeddings_dir = Path(args.embeddings_dir) + if not embeddings_dir.exists() or not embeddings_dir.is_dir(): + raise ValueError("--embeddings_dir must point to an existing directory") + added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) + print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}") # Convert the initializer_token, placeholder_token to ids initializer_token_ids = torch.stack([ @@ -637,6 +641,9 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + def keyword_filter(item: CSVDataItem): + return any(keyword in item.prompt for keyword in args.placeholder_token) + def collate_fn(examples): prompts = [example["prompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] @@ -677,7 +684,7 @@ def main(): template_key=args.train_data_template, valid_set_size=args.valid_set_size, 
num_workers=args.dataloader_num_workers, - keyword_filter=args.placeholder_token, + filter=keyword_filter, collate_fn=collate_fn ) -- cgit v1.2.3-54-g00ecf