From 279174a7a31f0fc6ed209e5b46901e50fe722c87 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 14 Dec 2022 09:43:45 +0100
Subject: More generic dataset filter

---
 common.py            |  4 ++--
 data/csv.py          | 10 +++++-----
 dreambooth.py        |  6 +++++-
 infer.py             |  5 +++--
 textual_inversion.py | 13 ++++++++++---
 5 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/common.py b/common.py
index 8d6b55d..7ffa77f 100644
--- a/common.py
+++ b/common.py
@@ -18,7 +18,7 @@ def load_text_embedding(embeddings, token_id, file):
 
 def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-        return 0
+        return []
 
     files = [file for file in embeddings_dir.iterdir() if file.is_file()]
 
@@ -33,4 +33,4 @@ def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel,
     for (token_id, file) in zip(token_ids, files):
         load_text_embedding(token_embeds, token_id, file)
 
-    return added
+    return tokens

diff --git a/data/csv.py b/data/csv.py
index 9c3c3f8..20ac992 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -7,7 +7,7 @@ import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import Dict, NamedTuple, List, Optional, Union
+from typing import Dict, NamedTuple, List, Optional, Union, Callable
 
 from models.clip.prompt import PromptProcessor
 
@@ -57,7 +57,7 @@ class CSVDataModule(pl.LightningDataModule):
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
-        keyword_filter: list[str] = [],
+        filter: Optional[Callable[[CSVDataItem], bool]] = None,
         collate_fn=None,
         num_workers: int = 0
     ):
@@ -84,7 +84,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
         self.generator = generator
-        self.keyword_filter = keyword_filter
+        self.filter = filter
         self.collate_fn = collate_fn
         self.num_workers = num_workers
         self.batch_size = batch_size
@@ -105,10 +105,10 @@ class CSVDataModule(pl.LightningDataModule):
         ]
 
     def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
-        if len(self.keyword_filter) == 0:
+        if self.filter is None:
             return items
 
-        return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)]
+        return [item for item in items if self.filter(item)]
 
     def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
         image_multiplier = max(math.ceil(num_class_images / len(items)), 1)

diff --git a/dreambooth.py b/dreambooth.py
index 3f45754..96213d0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -629,7 +629,11 @@ def main():
     vae.requires_grad_(False)
 
     if args.embeddings_dir is not None:
-        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids

diff --git a/infer.py b/infer.py
index 1fd11e2..efeb24d 100644
--- a/infer.py
+++ b/infer.py
@@ -181,7 +181,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def create_pipeline(model, ti_embeddings_dir, dtype):
+def create_pipeline(model, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -190,7 +190,8 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir))
+    added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,

diff --git a/textual_inversion.py b/textual_inversion.py
index 6d8fd77..a849d2a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,7 +25,7 @@ from slugify import slugify
 from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
-from data.csv import CSVDataModule
+from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
@@ -559,7 +559,11 @@ def main():
         text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
-        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
@@ -637,6 +641,9 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    def keyword_filter(item: CSVDataItem):
+        return any(keyword in item.prompt for keyword in args.placeholder_token)
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -677,7 +684,7 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
-        keyword_filter=args.placeholder_token,
+        filter=keyword_filter,
         collate_fn=collate_fn
     )
 

-- 
cgit v1.2.3-70-g09d2
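
Note: with this change, CSVDataModule no longer interprets a list of keywords itself. Callers pass any predicate over CSVDataItem, and filter_items keeps the items for which the predicate returns True (passing None disables filtering). The following is a minimal, self-contained sketch of that contract, not code from the repository: the CSVDataItem stand-in is trimmed to the one field the predicate reads (the real NamedTuple in data/csv.py carries more fields), and the example tokens are hypothetical.

from typing import Callable, NamedTuple, Optional

class CSVDataItem(NamedTuple):
    # Stand-in for data.csv.CSVDataItem, reduced to the field used below.
    prompt: str

def filter_items(
    items: list[CSVDataItem],
    item_filter: Optional[Callable[[CSVDataItem], bool]],
) -> list[CSVDataItem]:
    # Mirrors CSVDataModule.filter_items after this patch:
    # no predicate means "keep every item".
    if item_filter is None:
        return items
    return [item for item in items if item_filter(item)]

items = [CSVDataItem("a photo of <token>"), CSVDataItem("a landscape at dusk")]

# The old keyword behavior survives as one possible predicate, built the
# same way as keyword_filter in textual_inversion.py:
keywords = ["<token>"]  # hypothetical placeholder tokens
kept = filter_items(items, lambda item: any(k in item.prompt for k in keywords))
print(kept)  # -> [CSVDataItem(prompt='a photo of <token>')]

Because the parameter is now an arbitrary Callable[[CSVDataItem], bool], callers can also filter on properties the old list[str] parameter could not express, such as prompt length or other CSVDataItem fields.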