Diffstat (limited to 'dreambooth.py')

-rw-r--r--  dreambooth.py | 26 +++++---------------------
1 file changed, 5 insertions(+), 21 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 5521b21..3f45754 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -21,6 +21,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
+from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
@@ -125,7 +126,7 @@ def parse_args():
     parser.add_argument(
         "--embeddings_dir",
         type=str,
-        default="embeddings_ti",
+        default=None,
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
@@ -578,8 +579,6 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    embeddings_dir = Path(args.embeddings_dir)
-
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{basepath}",
@@ -629,6 +628,9 @@
     # Freeze text_encoder and vae
     vae.requires_grad_(False)
 
+    if args.embeddings_dir is not None:
+        load_text_embeddings(tokenizer, text_encoder, Path(args.embeddings_dir))
+
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
@@ -645,24 +647,6 @@
         text_encoder.resize_token_embeddings(len(tokenizer))
 
         token_embeds = text_encoder.get_input_embeddings().weight.data
-
-        print(f"Token ID mappings:")
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            embedding_file = embeddings_dir.joinpath(f"{token}.bin")
-            embedding_source = "init"
-
-            if embedding_file.exists() and embedding_file.is_file():
-                embedding_data = torch.load(embedding_file, map_location="cpu")
-
-                emb = next(iter(embedding_data.values()))
-                if len(emb.shape) == 1:
-                    emb = emb.unsqueeze(0)
-
-                token_embeds[token_id] = emb
-                embedding_source = "file"
-
-            print(f"- {token_id} {token} ({embedding_source})")
-
         original_token_embeds = token_embeds.clone().to(accelerator.device)
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
 
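The inline loop removed in the last hunk now lives in common.load_text_embeddings, and common.py itself is not part of this diff. The following is only a rough sketch of what that helper plausibly looks like, reconstructed from the removed loop; the glob-based token discovery and the add_tokens/resize_token_embeddings calls are assumptions, inferred from the fact that main() now calls the helper before the placeholder tokens are registered.

# Hypothetical sketch of common.load_text_embeddings; the real helper is
# outside this diff and may differ. Reconstructed from the removed loop.
from pathlib import Path

import torch
from transformers import CLIPTextModel, CLIPTokenizer


def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
    for file in embeddings_dir.glob("*.bin"):  # assumption: one checkpoint per token
        token = file.stem

        # Textual Inversion checkpoints hold a single {token: tensor} entry.
        embedding_data = torch.load(file, map_location="cpu")
        emb = next(iter(embedding_data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)

        # Assumption: the helper registers the token itself, then copies the
        # learned embedding over the freshly initialized row.
        tokenizer.add_tokens(token)
        text_encoder.resize_token_embeddings(len(tokenizer))
        token_id = tokenizer.convert_tokens_to_ids(token)
        text_encoder.get_input_embeddings().weight.data[token_id] = emb

Together with the new default=None for --embeddings_dir, loading stored embeddings is now opt-in: pass --embeddings_dir embeddings_ti to restore the previous default behavior.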
