diff options
author | Volpeon <git@volpeon.ink> | 2022-12-14 09:43:45 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-14 09:43:45 +0100 |
commit | 279174a7a31f0fc6ed209e5b46901e50fe722c87 (patch) | |
tree | ec12ec9a66c5e6532aa0be08608c638283e090fb | |
parent | Unified loading of TI embeddings (diff) | |
download | textual-inversion-diff-279174a7a31f0fc6ed209e5b46901e50fe722c87.tar.gz textual-inversion-diff-279174a7a31f0fc6ed209e5b46901e50fe722c87.tar.bz2 textual-inversion-diff-279174a7a31f0fc6ed209e5b46901e50fe722c87.zip |
More generic datset filter
-rw-r--r-- | common.py | 4 | ||||
-rw-r--r-- | data/csv.py | 10 | ||||
-rw-r--r-- | dreambooth.py | 6 | ||||
-rw-r--r-- | infer.py | 5 | ||||
-rw-r--r-- | textual_inversion.py | 13 |
5 files changed, 25 insertions, 13 deletions
@@ -18,7 +18,7 @@ def load_text_embedding(embeddings, token_id, file): | |||
18 | 18 | ||
19 | def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path): | 19 | def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path): |
20 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 20 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
21 | return 0 | 21 | return [] |
22 | 22 | ||
23 | files = [file for file in embeddings_dir.iterdir() if file.is_file()] | 23 | files = [file for file in embeddings_dir.iterdir() if file.is_file()] |
24 | 24 | ||
@@ -33,4 +33,4 @@ def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, | |||
33 | for (token_id, file) in zip(token_ids, files): | 33 | for (token_id, file) in zip(token_ids, files): |
34 | load_text_embedding(token_embeds, token_id, file) | 34 | load_text_embedding(token_embeds, token_id, file) |
35 | 35 | ||
36 | return added | 36 | return tokens |
diff --git a/data/csv.py b/data/csv.py index 9c3c3f8..20ac992 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -7,7 +7,7 @@ import pytorch_lightning as pl | |||
7 | from PIL import Image | 7 | from PIL import Image |
8 | from torch.utils.data import Dataset, DataLoader, random_split | 8 | from torch.utils.data import Dataset, DataLoader, random_split |
9 | from torchvision import transforms | 9 | from torchvision import transforms |
10 | from typing import Dict, NamedTuple, List, Optional, Union | 10 | from typing import Dict, NamedTuple, List, Optional, Union, Callable |
11 | 11 | ||
12 | from models.clip.prompt import PromptProcessor | 12 | from models.clip.prompt import PromptProcessor |
13 | 13 | ||
@@ -57,7 +57,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
57 | template_key: str = "template", | 57 | template_key: str = "template", |
58 | valid_set_size: Optional[int] = None, | 58 | valid_set_size: Optional[int] = None, |
59 | generator: Optional[torch.Generator] = None, | 59 | generator: Optional[torch.Generator] = None, |
60 | keyword_filter: list[str] = [], | 60 | filter: Optional[Callable[[CSVDataItem], bool]] = None, |
61 | collate_fn=None, | 61 | collate_fn=None, |
62 | num_workers: int = 0 | 62 | num_workers: int = 0 |
63 | ): | 63 | ): |
@@ -84,7 +84,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
84 | self.interpolation = interpolation | 84 | self.interpolation = interpolation |
85 | self.valid_set_size = valid_set_size | 85 | self.valid_set_size = valid_set_size |
86 | self.generator = generator | 86 | self.generator = generator |
87 | self.keyword_filter = keyword_filter | 87 | self.filter = filter |
88 | self.collate_fn = collate_fn | 88 | self.collate_fn = collate_fn |
89 | self.num_workers = num_workers | 89 | self.num_workers = num_workers |
90 | self.batch_size = batch_size | 90 | self.batch_size = batch_size |
@@ -105,10 +105,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
105 | ] | 105 | ] |
106 | 106 | ||
107 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: | 107 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: |
108 | if len(self.keyword_filter) == 0: | 108 | if self.filter is None: |
109 | return items | 109 | return items |
110 | 110 | ||
111 | return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)] | 111 | return [item for item in items if self.filter(item)] |
112 | 112 | ||
113 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: | 113 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: |
114 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) | 114 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) |
diff --git a/dreambooth.py b/dreambooth.py index 3f45754..96213d0 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -629,7 +629,11 @@ def main(): | |||
629 | vae.requires_grad_(False) | 629 | vae.requires_grad_(False) |
630 | 630 | ||
631 | if args.embeddings_dir is not None: | 631 | if args.embeddings_dir is not None: |
632 | load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) | 632 | embeddings_dir = Path(args.embeddings_dir) |
633 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
634 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
635 | added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) | ||
636 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") | ||
633 | 637 | ||
634 | if len(args.placeholder_token) != 0: | 638 | if len(args.placeholder_token) != 0: |
635 | # Convert the initializer_token, placeholder_token to ids | 639 | # Convert the initializer_token, placeholder_token to ids |
@@ -181,7 +181,7 @@ def save_args(basepath, args, extra={}): | |||
181 | json.dump(info, f, indent=4) | 181 | json.dump(info, f, indent=4) |
182 | 182 | ||
183 | 183 | ||
184 | def create_pipeline(model, ti_embeddings_dir, dtype): | 184 | def create_pipeline(model, embeddings_dir, dtype): |
185 | print("Loading Stable Diffusion pipeline...") | 185 | print("Loading Stable Diffusion pipeline...") |
186 | 186 | ||
187 | tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) | 187 | tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) |
@@ -190,7 +190,8 @@ def create_pipeline(model, ti_embeddings_dir, dtype): | |||
190 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 190 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) |
191 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | 191 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) |
192 | 192 | ||
193 | load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir)) | 193 | added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) |
194 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") | ||
194 | 195 | ||
195 | pipeline = VlpnStableDiffusion( | 196 | pipeline = VlpnStableDiffusion( |
196 | text_encoder=text_encoder, | 197 | text_encoder=text_encoder, |
diff --git a/textual_inversion.py b/textual_inversion.py index 6d8fd77..a849d2a 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -25,7 +25,7 @@ from slugify import slugify | |||
25 | from common import load_text_embeddings, load_text_embedding | 25 | from common import load_text_embeddings, load_text_embedding |
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from pipelines.util import set_use_memory_efficient_attention_xformers | 27 | from pipelines.util import set_use_memory_efficient_attention_xformers |
28 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule, CSVDataItem |
29 | from training.optimization import get_one_cycle_schedule | 29 | from training.optimization import get_one_cycle_schedule |
30 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
31 | 31 | ||
@@ -559,7 +559,11 @@ def main(): | |||
559 | text_encoder.gradient_checkpointing_enable() | 559 | text_encoder.gradient_checkpointing_enable() |
560 | 560 | ||
561 | if args.embeddings_dir is not None: | 561 | if args.embeddings_dir is not None: |
562 | load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) | 562 | embeddings_dir = Path(args.embeddings_dir) |
563 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
564 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
565 | added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) | ||
566 | print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}") | ||
563 | 567 | ||
564 | # Convert the initializer_token, placeholder_token to ids | 568 | # Convert the initializer_token, placeholder_token to ids |
565 | initializer_token_ids = torch.stack([ | 569 | initializer_token_ids = torch.stack([ |
@@ -637,6 +641,9 @@ def main(): | |||
637 | elif args.mixed_precision == "bf16": | 641 | elif args.mixed_precision == "bf16": |
638 | weight_dtype = torch.bfloat16 | 642 | weight_dtype = torch.bfloat16 |
639 | 643 | ||
644 | def keyword_filter(item: CSVDataItem): | ||
645 | return any(keyword in item.prompt for keyword in args.placeholder_token) | ||
646 | |||
640 | def collate_fn(examples): | 647 | def collate_fn(examples): |
641 | prompts = [example["prompts"] for example in examples] | 648 | prompts = [example["prompts"] for example in examples] |
642 | nprompts = [example["nprompts"] for example in examples] | 649 | nprompts = [example["nprompts"] for example in examples] |
@@ -677,7 +684,7 @@ def main(): | |||
677 | template_key=args.train_data_template, | 684 | template_key=args.train_data_template, |
678 | valid_set_size=args.valid_set_size, | 685 | valid_set_size=args.valid_set_size, |
679 | num_workers=args.dataloader_num_workers, | 686 | num_workers=args.dataloader_num_workers, |
680 | keyword_filter=args.placeholder_token, | 687 | filter=keyword_filter, |
681 | collate_fn=collate_fn | 688 | collate_fn=collate_fn |
682 | ) | 689 | ) |
683 | 690 | ||