Diffstat (limited to 'dreambooth.py')
 dreambooth.py | 26 +++++---------------------
 1 file changed, 5 insertions(+), 21 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 5521b21..3f45754 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -21,6 +21,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -125,7 +126,7 @@ def parse_args():
     parser.add_argument(
         "--embeddings_dir",
         type=str,
-        default="embeddings_ti",
+        default=None,
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
@@ -578,8 +579,6 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    embeddings_dir = Path(args.embeddings_dir)
-
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -629,6 +628,9 @@ def main():
     # Freeze text_encoder and vae
     vae.requires_grad_(False)
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
@@ -645,24 +647,6 @@ def main():
         text_encoder.resize_token_embeddings(len(tokenizer))
 
         token_embeds = text_encoder.get_input_embeddings().weight.data
-
-        print(f"Token ID mappings:")
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = embeddings_dir.joinpath(f"{token}.bin")
-            embedding_source = "init"
-
-            if embedding_file.exists() and embedding_file.is_file():
-                embedding_data = torch.load(embedding_file, map_location="cpu")
-
-                emb = next(iter(embedding_data.values()))
-                if len(emb.shape) == 1:
-                    emb = emb.unsqueeze(0)
-
-                token_embeds[token_id] = emb
-                embedding_source = "file"
-
-            print(f"- {token_id} {token} ({embedding_source})")
-
         original_token_embeds = token_embeds.clone().to(accelerator.device)
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
 
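The deleted hunk above is the old inline Textual Inversion loading loop; the diff replaces it with a single call to load_text_embeddings imported from common, and --embeddings_dir now defaults to None so the loading step is opt-in. common.py itself is not part of this diff, so the following is only a sketch of what such a helper might look like, reconstructed from the removed inline logic; the function body, the filename-equals-token convention, and the return value are assumptions.

# Hypothetical sketch of common.load_text_embeddings, reconstructed from the
# inline loop removed above. Not the actual implementation in common.py.
from pathlib import Path

import torch
from transformers import CLIPTextModel, CLIPTokenizer


def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
    """Load every Textual Inversion embedding found in embeddings_dir into the text encoder."""
    if not embeddings_dir.is_dir():
        return []

    loaded = []
    for file in sorted(embeddings_dir.glob("*.bin")):
        token = file.stem  # assumption: each file is named after its placeholder token

        # Register the token (a no-op if it already exists) and grow the embedding matrix.
        tokenizer.add_tokens(token)
        text_encoder.resize_token_embeddings(len(tokenizer))
        token_id = tokenizer.convert_tokens_to_ids(token)

        # Same convention as the removed inline code: the .bin file holds a dict
        # with a single entry containing the embedding tensor.
        embedding_data = torch.load(file, map_location="cpu")
        emb = next(iter(embedding_data.values()))
        if emb.dim() > 1:
            emb = emb.squeeze(0)

        token_embeds = text_encoder.get_input_embeddings().weight.data
        token_embeds[token_id] = emb
        loaded.append(token)

    print(f"Loaded Textual Inversion embeddings: {loaded}")
    return loaded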