 common.py            |  4 ++--
 data/csv.py          | 10 +++++-----
 dreambooth.py        |  6 +++++-
 infer.py             |  5 +++--
 textual_inversion.py | 13 ++++++++++---
 5 files changed, 25 insertions(+), 13 deletions(-)
diff --git a/common.py b/common.py
index 8d6b55d..7ffa77f 100644
--- a/common.py
+++ b/common.py
@@ -18,7 +18,7 @@ def load_text_embedding(embeddings, token_id, file):
 
 def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-        return 0
+        return []
 
     files = [file for file in embeddings_dir.iterdir() if file.is_file()]
 
@@ -33,4 +33,4 @@ def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel,
     for (token_id, file) in zip(token_ids, files):
         load_text_embedding(token_embeds, token_id, file)
 
-    return added
+    return tokens
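
The return value of load_text_embeddings changes from a count to the list of
token strings that were added, so callers can report exactly which embeddings
were loaded. A minimal sketch of the new contract, assuming a pretrained CLIP
checkpoint and an "embeddings" directory of saved vectors (both illustrative):

    from pathlib import Path

    from transformers import CLIPTextModel, CLIPTokenizer

    from common import load_text_embeddings

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Now returns the added token strings ([] if the directory is missing)
    # instead of an integer count.
    added_tokens = load_text_embeddings(tokenizer, text_encoder, Path("embeddings"))
    print(f"Added {len(added_tokens)} tokens: {added_tokens}")
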
diff --git a/data/csv.py b/data/csv.py
index 9c3c3f8..20ac992 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -7,7 +7,7 @@ import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import Dict, NamedTuple, List, Optional, Union
+from typing import Dict, NamedTuple, List, Optional, Union, Callable
 
 from models.clip.prompt import PromptProcessor
 
@@ -57,7 +57,7 @@ class CSVDataModule(pl.LightningDataModule):
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
-        keyword_filter: list[str] = [],
+        filter: Optional[Callable[[CSVDataItem], bool]] = None,
         collate_fn=None,
         num_workers: int = 0
     ):
@@ -84,7 +84,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
         self.generator = generator
-        self.keyword_filter = keyword_filter
+        self.filter = filter
         self.collate_fn = collate_fn
         self.num_workers = num_workers
         self.batch_size = batch_size
@@ -105,10 +105,10 @@ class CSVDataModule(pl.LightningDataModule):
         ]
 
     def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
-        if len(self.keyword_filter) == 0:
+        if self.filter is None:
             return items
 
-        return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)]
+        return [item for item in items if self.filter(item)]
 
     def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
         image_multiplier = max(math.ceil(num_class_images / len(items)), 1)
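
CSVDataModule now accepts an arbitrary predicate over CSVDataItem instead of a
keyword list, so callers can express any selection logic and pass it to the
constructor as filter=keyword_filter. A hypothetical sketch (the placeholder
token values are illustrative):

    from data.csv import CSVDataItem

    placeholder_tokens = ["<token1>", "<token2>"]  # illustrative values

    def keyword_filter(item: CSVDataItem) -> bool:
        # Reproduces the behavior of the removed keyword_filter parameter.
        return any(keyword in item.prompt for keyword in placeholder_tokens)

    def min_prompt_length_filter(item: CSVDataItem) -> bool:
        # A predicate the old list-of-keywords API could not express.
        return len(item.prompt.split()) >= 5
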
diff --git a/dreambooth.py b/dreambooth.py
index 3f45754..96213d0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -629,7 +629,11 @@ def main():
     vae.requires_grad_(False)
 
     if args.embeddings_dir is not None:
-        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
diff --git a/infer.py b/infer.py
index 1fd11e2..efeb24d 100644
--- a/infer.py
+++ b/infer.py
@@ -181,7 +181,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def create_pipeline(model, ti_embeddings_dir, dtype):
+def create_pipeline(model, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -190,7 +190,8 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir))
+    added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
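
Note that the Path(...) wrapper was dropped inside create_pipeline, so the
caller is now responsible for passing a pathlib.Path rather than a plain
string. A hypothetical call site (the model id and directory are illustrative):

    from pathlib import Path

    import torch

    pipeline = create_pipeline("runwayml/stable-diffusion-v1-5", Path("embeddings"), torch.float16)
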
diff --git a/textual_inversion.py b/textual_inversion.py
index 6d8fd77..a849d2a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,7 +25,7 @@ from slugify import slugify
 from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
-from data.csv import CSVDataModule
+from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
@@ -559,7 +559,11 @@ def main():
         text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
-        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
@@ -637,6 +641,9 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    def keyword_filter(item: CSVDataItem):
+        return any(keyword in item.prompt for keyword in args.placeholder_token)
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -677,7 +684,7 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
-        keyword_filter=args.placeholder_token,
+        filter=keyword_filter,
         collate_fn=collate_fn
     )
 