-rw-r--r--  common.py             36
-rw-r--r--  dreambooth.py         26
-rw-r--r--  infer.py              34
-rw-r--r--  textual_inversion.py  23
4 files changed, 54 insertions, 65 deletions
diff --git a/common.py b/common.py
new file mode 100644
index 0000000..8d6b55d
--- /dev/null
+++ b/common.py
@@ -0,0 +1,36 @@
+from pathlib import Path
+import torch
+
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+def load_text_embedding(embeddings, token_id, file):
+    data = torch.load(file, map_location="cpu")
+
+    assert len(data.keys()) == 1, 'embedding data has multiple terms in it'
+
+    emb = next(iter(data.values()))
+    if len(emb.shape) == 1:
+        emb = emb.unsqueeze(0)
+
+    embeddings[token_id] = emb
+
+
+def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
+    if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+        return 0
+
+    files = [file for file in embeddings_dir.iterdir() if file.is_file()]
+
+    tokens = [file.stem for file in files]
+    added = tokenizer.add_tokens(tokens)
+    token_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
+    for (token_id, file) in zip(token_ids, files):
+        load_text_embedding(token_embeds, token_id, file)
+
+    return added
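
For context, a minimal sketch of how the new helpers in common.py are meant to be called; the model name and embeddings directory below are placeholders for illustration, not part of this commit ("embeddings_ti" is the old default from dreambooth.py):

    from pathlib import Path

    from transformers import CLIPTextModel, CLIPTokenizer

    from common import load_text_embeddings

    # Placeholder model; any checkpoint layout with "tokenizer" and
    # "text_encoder" subfolders works the same way.
    model = "runwayml/stable-diffusion-v1-5"
    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder")

    # Every file in the directory becomes one token named after its stem
    # (e.g. embeddings_ti/<cat-toy>.bin registers the token "<cat-toy>").
    # The return value is the number of tokens the tokenizer added.
    added = load_text_embeddings(tokenizer, text_encoder, Path("embeddings_ti"))
    print(f"Loaded {added} Textual Inversion embeddings")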
diff --git a/dreambooth.py b/dreambooth.py
index 5521b21..3f45754 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -21,6 +21,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -125,7 +126,7 @@ def parse_args():
     parser.add_argument(
         "--embeddings_dir",
         type=str,
-        default="embeddings_ti",
+        default=None,
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
@@ -578,8 +579,6 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    embeddings_dir = Path(args.embeddings_dir)
-
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -629,6 +628,9 @@ def main():
     # Freeze text_encoder and vae
     vae.requires_grad_(False)
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
@@ -645,24 +647,6 @@ def main():
         text_encoder.resize_token_embeddings(len(tokenizer))
 
         token_embeds = text_encoder.get_input_embeddings().weight.data
-
-        print(f"Token ID mappings:")
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = embeddings_dir.joinpath(f"{token}.bin")
-            embedding_source = "init"
-
-            if embedding_file.exists() and embedding_file.is_file():
-                embedding_data = torch.load(embedding_file, map_location="cpu")
-
-                emb = next(iter(embedding_data.values()))
-                if len(emb.shape) == 1:
-                    emb = emb.unsqueeze(0)
-
-                token_embeds[token_id] = emb
-                embedding_source = "file"
-
-            print(f"- {token_id} {token} ({embedding_source})")
-
         original_token_embeds = token_embeds.clone().to(accelerator.device)
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
 
diff --git a/infer.py b/infer.py
index f607041..1fd11e2 100644
--- a/infer.py
+++ b/infer.py
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from common import load_text_embeddings
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -180,37 +181,6 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
-    print(f"Loading Textual Inversion embeddings")
-
-    embeddings_dir = Path(embeddings_dir)
-    embeddings_dir.mkdir(parents=True, exist_ok=True)
-
-    placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()]
-    tokenizer.add_tokens(placeholder_tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
-
-            data = torch.load(file, map_location="cpu")
-
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[placeholder_token_id] = emb
-
-            print(f"Loaded {placeholder_token}")
-
-
 def create_pipeline(model, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
@@ -220,7 +190,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
+    load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir))
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
diff --git a/textual_inversion.py b/textual_inversion.py
index fd4a313..6d8fd77 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -22,6 +22,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -105,6 +106,12 @@ def parse_args():
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -551,6 +558,9 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
         torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -562,10 +572,6 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
-    print(f"Token ID mappings:")
-    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-        print(f"- {token_id} {token}")
-
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
571 577
@@ -576,14 +582,7 @@ def main():
         resumepath = Path(args.resume_from).joinpath("checkpoints")
 
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin")
-            embedding_data = torch.load(embedding_file, map_location="cpu")
-
-            emb = next(iter(embedding_data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[token_id] = emb
+            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
 
     original_token_embeds = token_embeds.clone().to(accelerator.device)
 
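And a self-contained sketch of the single-file helper load_text_embedding, which both the resume path above and load_text_embeddings build on; the file path, token id, matrix size, and the 768-dim embedding width are illustrative assumptions, not values from this commit:

    import torch

    from common import load_text_embedding

    # Fabricate an embedding file in the one-entry dict format the helper
    # asserts on: {token: tensor}. 768 matches the base CLIP text encoder.
    torch.save({"<token>": torch.randn(768)}, "/tmp/token_end.bin")

    # Stands in for text_encoder.get_input_embeddings().weight.data.
    embeddings = torch.zeros(49409, 768)

    # A 1-D tensor is unsqueezed to shape (1, 768) before being written
    # into row 49408 of the embedding matrix.
    load_text_embedding(embeddings, 49408, "/tmp/token_end.bin")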