diff options
| -rw-r--r-- | common.py | 36 | ||||
| -rw-r--r-- | dreambooth.py | 26 | ||||
| -rw-r--r-- | infer.py | 34 | ||||
| -rw-r--r-- | textual_inversion.py | 23 |
4 files changed, 54 insertions, 65 deletions
diff --git a/common.py b/common.py new file mode 100644 index 0000000..8d6b55d --- /dev/null +++ b/common.py | |||
| @@ -0,0 +1,36 @@ | |||
| 1 | from pathlib import Path | ||
| 2 | import torch | ||
| 3 | |||
| 4 | from transformers import CLIPTextModel, CLIPTokenizer | ||
| 5 | |||
| 6 | |||
def load_text_embedding(embeddings, token_id, file):
    """Copy the single embedding stored in ``file`` into ``embeddings[token_id]``.

    Args:
        embeddings: Indexable embedding table that is written in place
            (callers pass the text encoder's input-embedding weight data).
        token_id: Index of the row to overwrite.
        file: Path of a file readable by ``torch.load`` that holds a mapping
            with exactly one entry; the entry's value is the embedding.

    Raises:
        ValueError: If the file does not contain exactly one embedding.
    """
    # NOTE(security): torch.load may unpickle arbitrary objects; only load
    # embedding files that come from a trusted source.
    data = torch.load(file, map_location="cpu")

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would turn an empty file into a StopIteration and let
    # a multi-term file silently load only its first entry.
    if len(data) != 1:
        raise ValueError(
            f"expected exactly one embedding in {file}, found {len(data)}"
        )

    emb = next(iter(data.values()))
    if emb.dim() == 1:
        # Promote a bare vector to a single-row matrix before assignment.
        emb = emb.unsqueeze(0)

    embeddings[token_id] = emb
| 17 | |||
| 18 | |||
def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
    """Register every embedding file in ``embeddings_dir`` as a token.

    Each regular file in the directory contributes one token named after the
    file's stem. The tokens are added to ``tokenizer``, the text encoder's
    embedding table is resized to the new tokenizer length, and each file's
    embedding is written into the row of its token id.

    Args:
        tokenizer: Tokenizer that receives the new tokens.
        text_encoder: Text encoder whose input embeddings are resized and
            overwritten in place.
        embeddings_dir: Directory to scan (a ``Path`` or a plain string).

    Returns:
        The value returned by ``tokenizer.add_tokens`` for the file stems, or
        ``0`` when ``embeddings_dir`` is not an existing directory.
    """
    embeddings_dir = Path(embeddings_dir)

    # is_dir() is already False for a missing path, so one check suffices.
    if not embeddings_dir.is_dir():
        return 0

    # iterdir() yields entries in arbitrary order; sort so that the ids handed
    # out to new tokens are reproducible across runs and machines.
    files = sorted(file for file in embeddings_dir.iterdir() if file.is_file())

    tokens = [file.stem for file in files]
    added = tokenizer.add_tokens(tokens)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Grow the embedding table before writing rows for the new token ids.
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_embeds = text_encoder.get_input_embeddings().weight.data

    for token_id, file in zip(token_ids, files):
        load_text_embedding(token_embeds, token_id, file)

    return added
diff --git a/dreambooth.py b/dreambooth.py index 5521b21..3f45754 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -21,6 +21,7 @@ from tqdm.auto import tqdm | |||
| 21 | from transformers import CLIPTextModel, CLIPTokenizer | 21 | from transformers import CLIPTextModel, CLIPTokenizer |
| 22 | from slugify import slugify | 22 | from slugify import slugify |
| 23 | 23 | ||
| 24 | from common import load_text_embeddings | ||
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | from pipelines.util import set_use_memory_efficient_attention_xformers | 26 | from pipelines.util import set_use_memory_efficient_attention_xformers |
| 26 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
| @@ -125,7 +126,7 @@ def parse_args(): | |||
| 125 | parser.add_argument( | 126 | parser.add_argument( |
| 126 | "--embeddings_dir", | 127 | "--embeddings_dir", |
| 127 | type=str, | 128 | type=str, |
| 128 | default="embeddings_ti", | 129 | default=None, |
| 129 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 130 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 130 | ) | 131 | ) |
| 131 | parser.add_argument( | 132 | parser.add_argument( |
| @@ -578,8 +579,6 @@ def main(): | |||
| 578 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 579 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) |
| 579 | basepath.mkdir(parents=True, exist_ok=True) | 580 | basepath.mkdir(parents=True, exist_ok=True) |
| 580 | 581 | ||
| 581 | embeddings_dir = Path(args.embeddings_dir) | ||
| 582 | |||
| 583 | accelerator = Accelerator( | 582 | accelerator = Accelerator( |
| 584 | log_with=LoggerType.TENSORBOARD, | 583 | log_with=LoggerType.TENSORBOARD, |
| 585 | logging_dir=f"{basepath}", | 584 | logging_dir=f"{basepath}", |
| @@ -629,6 +628,9 @@ def main(): | |||
| 629 | # Freeze text_encoder and vae | 628 | # Freeze text_encoder and vae |
| 630 | vae.requires_grad_(False) | 629 | vae.requires_grad_(False) |
| 631 | 630 | ||
| 631 | if args.embeddings_dir is not None: | ||
| 632 | load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) | ||
| 633 | |||
| 632 | if len(args.placeholder_token) != 0: | 634 | if len(args.placeholder_token) != 0: |
| 633 | # Convert the initializer_token, placeholder_token to ids | 635 | # Convert the initializer_token, placeholder_token to ids |
| 634 | initializer_token_ids = torch.stack([ | 636 | initializer_token_ids = torch.stack([ |
| @@ -645,24 +647,6 @@ def main(): | |||
| 645 | text_encoder.resize_token_embeddings(len(tokenizer)) | 647 | text_encoder.resize_token_embeddings(len(tokenizer)) |
| 646 | 648 | ||
| 647 | token_embeds = text_encoder.get_input_embeddings().weight.data | 649 | token_embeds = text_encoder.get_input_embeddings().weight.data |
| 648 | |||
| 649 | print(f"Token ID mappings:") | ||
| 650 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | ||
| 651 | embedding_file = embeddings_dir.joinpath(f"{token}.bin") | ||
| 652 | embedding_source = "init" | ||
| 653 | |||
| 654 | if embedding_file.exists() and embedding_file.is_file(): | ||
| 655 | embedding_data = torch.load(embedding_file, map_location="cpu") | ||
| 656 | |||
| 657 | emb = next(iter(embedding_data.values())) | ||
| 658 | if len(emb.shape) == 1: | ||
| 659 | emb = emb.unsqueeze(0) | ||
| 660 | |||
| 661 | token_embeds[token_id] = emb | ||
| 662 | embedding_source = "file" | ||
| 663 | |||
| 664 | print(f"- {token_id} {token} ({embedding_source})") | ||
| 665 | |||
| 666 | original_token_embeds = token_embeds.clone().to(accelerator.device) | 650 | original_token_embeds = token_embeds.clone().to(accelerator.device) |
| 667 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 651 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
| 668 | 652 | ||
| @@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
| 24 | from slugify import slugify | 24 | from slugify import slugify |
| 25 | 25 | ||
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 27 | from common import load_text_embeddings | ||
| 27 | 28 | ||
| 28 | 29 | ||
| 29 | torch.backends.cuda.matmul.allow_tf32 = True | 30 | torch.backends.cuda.matmul.allow_tf32 = True |
| @@ -180,37 +181,6 @@ def save_args(basepath, args, extra={}): | |||
| 180 | json.dump(info, f, indent=4) | 181 | json.dump(info, f, indent=4) |
| 181 | 182 | ||
| 182 | 183 | ||
| 183 | def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir): | ||
| 184 | print(f"Loading Textual Inversion embeddings") | ||
| 185 | |||
| 186 | embeddings_dir = Path(embeddings_dir) | ||
| 187 | embeddings_dir.mkdir(parents=True, exist_ok=True) | ||
| 188 | |||
| 189 | placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()] | ||
| 190 | tokenizer.add_tokens(placeholder_tokens) | ||
| 191 | |||
| 192 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
| 193 | |||
| 194 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
| 195 | |||
| 196 | for file in embeddings_dir.iterdir(): | ||
| 197 | if file.is_file(): | ||
| 198 | placeholder_token = file.stem | ||
| 199 | placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token) | ||
| 200 | |||
| 201 | data = torch.load(file, map_location="cpu") | ||
| 202 | |||
| 203 | assert len(data.keys()) == 1, 'embedding file has multiple terms in it' | ||
| 204 | |||
| 205 | emb = next(iter(data.values())) | ||
| 206 | if len(emb.shape) == 1: | ||
| 207 | emb = emb.unsqueeze(0) | ||
| 208 | |||
| 209 | token_embeds[placeholder_token_id] = emb | ||
| 210 | |||
| 211 | print(f"Loaded {placeholder_token}") | ||
| 212 | |||
| 213 | |||
| 214 | def create_pipeline(model, ti_embeddings_dir, dtype): | 184 | def create_pipeline(model, ti_embeddings_dir, dtype): |
| 215 | print("Loading Stable Diffusion pipeline...") | 185 | print("Loading Stable Diffusion pipeline...") |
| 216 | 186 | ||
| @@ -220,7 +190,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype): | |||
| 220 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 190 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) |
| 221 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | 191 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) |
| 222 | 192 | ||
| 223 | load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir) | 193 | load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir)) |
| 224 | 194 | ||
| 225 | pipeline = VlpnStableDiffusion( | 195 | pipeline = VlpnStableDiffusion( |
| 226 | text_encoder=text_encoder, | 196 | text_encoder=text_encoder, |
diff --git a/textual_inversion.py b/textual_inversion.py index fd4a313..6d8fd77 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -22,6 +22,7 @@ from tqdm.auto import tqdm | |||
| 22 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
| 23 | from slugify import slugify | 23 | from slugify import slugify |
| 24 | 24 | ||
| 25 | from common import load_text_embeddings, load_text_embedding | ||
| 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 26 | from pipelines.util import set_use_memory_efficient_attention_xformers | 27 | from pipelines.util import set_use_memory_efficient_attention_xformers |
| 27 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
| @@ -105,6 +106,12 @@ def parse_args(): | |||
| 105 | help="The output directory where the model predictions and checkpoints will be written.", | 106 | help="The output directory where the model predictions and checkpoints will be written.", |
| 106 | ) | 107 | ) |
| 107 | parser.add_argument( | 108 | parser.add_argument( |
| 109 | "--embeddings_dir", | ||
| 110 | type=str, | ||
| 111 | default=None, | ||
| 112 | help="The embeddings directory where Textual Inversion embeddings are stored.", | ||
| 113 | ) | ||
| 114 | parser.add_argument( | ||
| 108 | "--seed", | 115 | "--seed", |
| 109 | type=int, | 116 | type=int, |
| 110 | default=None, | 117 | default=None, |
| @@ -551,6 +558,9 @@ def main(): | |||
| 551 | unet.enable_gradient_checkpointing() | 558 | unet.enable_gradient_checkpointing() |
| 552 | text_encoder.gradient_checkpointing_enable() | 559 | text_encoder.gradient_checkpointing_enable() |
| 553 | 560 | ||
| 561 | if args.embeddings_dir is not None: | ||
| 562 | load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir)) | ||
| 563 | |||
| 554 | # Convert the initializer_token, placeholder_token to ids | 564 | # Convert the initializer_token, placeholder_token to ids |
| 555 | initializer_token_ids = torch.stack([ | 565 | initializer_token_ids = torch.stack([ |
| 556 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) | 566 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) |
| @@ -562,10 +572,6 @@ def main(): | |||
| 562 | 572 | ||
| 563 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 573 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
| 564 | 574 | ||
| 565 | print(f"Token ID mappings:") | ||
| 566 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | ||
| 567 | print(f"- {token_id} {token}") | ||
| 568 | |||
| 569 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | 575 | # Resize the token embeddings as we are adding new special tokens to the tokenizer |
| 570 | text_encoder.resize_token_embeddings(len(tokenizer)) | 576 | text_encoder.resize_token_embeddings(len(tokenizer)) |
| 571 | 577 | ||
| @@ -576,14 +582,7 @@ def main(): | |||
| 576 | resumepath = Path(args.resume_from).joinpath("checkpoints") | 582 | resumepath = Path(args.resume_from).joinpath("checkpoints") |
| 577 | 583 | ||
| 578 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | 584 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): |
| 579 | embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin") | 585 | load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin")) |
| 580 | embedding_data = torch.load(embedding_file, map_location="cpu") | ||
| 581 | |||
| 582 | emb = next(iter(embedding_data.values())) | ||
| 583 | if len(emb.shape) == 1: | ||
| 584 | emb = emb.unsqueeze(0) | ||
| 585 | |||
| 586 | token_embeds[token_id] = emb | ||
| 587 | 586 | ||
| 588 | original_token_embeds = token_embeds.clone().to(accelerator.device) | 587 | original_token_embeds = token_embeds.clone().to(accelerator.device) |
| 589 | 588 | ||
