diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-05 10:26:17 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-05 10:26:17 +0100 |
| commit | 4ef37d87d5a04bc6bb7dacee0660bba3057cc02f (patch) | |
| tree | ffcae0352c1b0784b18572f66bba3c617165cef1 | |
| parent | Various cleanups (diff) | |
| download | textual-inversion-diff-4ef37d87d5a04bc6bb7dacee0660bba3057cc02f.tar.gz textual-inversion-diff-4ef37d87d5a04bc6bb7dacee0660bba3057cc02f.tar.bz2 textual-inversion-diff-4ef37d87d5a04bc6bb7dacee0660bba3057cc02f.zip | |
Fix
| -rw-r--r-- | train_dreambooth.py | 4 | ||||
| -rw-r--r-- | train_ti.py | 19 |
2 files changed, 13 insertions, 10 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index c658ad6..1fd86b1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -612,7 +612,7 @@ def main(): | |||
| 612 | raise ValueError("--embeddings_dir must point to an existing directory") | 612 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 613 | 613 | ||
| 614 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 614 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
| 615 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}") | 615 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 616 | 616 | ||
| 617 | if len(args.placeholder_token) != 0: | 617 | if len(args.placeholder_token) != 0: |
| 618 | # Convert the initializer_token, placeholder_token to ids | 618 | # Convert the initializer_token, placeholder_token to ids |
| @@ -629,7 +629,7 @@ def main(): | |||
| 629 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids) | 629 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids) |
| 630 | ] | 630 | ] |
| 631 | 631 | ||
| 632 | print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}") | 632 | print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") |
| 633 | else: | 633 | else: |
| 634 | placeholder_token_id = [] | 634 | placeholder_token_id = [] |
| 635 | 635 | ||
diff --git a/train_ti.py b/train_ti.py index 5df6850..164cf67 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -450,7 +450,8 @@ class Checkpointer(CheckpointerBase): | |||
| 450 | tokenizer, | 450 | tokenizer, |
| 451 | text_encoder, | 451 | text_encoder, |
| 452 | scheduler, | 452 | scheduler, |
| 453 | new_tokens, | 453 | placeholder_token, |
| 454 | new_ids, | ||
| 454 | output_dir: Path, | 455 | output_dir: Path, |
| 455 | sample_image_size, | 456 | sample_image_size, |
| 456 | sample_batches, | 457 | sample_batches, |
| @@ -473,7 +474,8 @@ class Checkpointer(CheckpointerBase): | |||
| 473 | self.tokenizer = tokenizer | 474 | self.tokenizer = tokenizer |
| 474 | self.text_encoder = text_encoder | 475 | self.text_encoder = text_encoder |
| 475 | self.scheduler = scheduler | 476 | self.scheduler = scheduler |
| 476 | self.new_tokens = new_tokens | 477 | self.placeholder_token = placeholder_token |
| 478 | self.new_ids = new_ids | ||
| 477 | 479 | ||
| 478 | @torch.no_grad() | 480 | @torch.no_grad() |
| 479 | def checkpoint(self, step, postfix): | 481 | def checkpoint(self, step, postfix): |
| @@ -484,10 +486,10 @@ class Checkpointer(CheckpointerBase): | |||
| 484 | 486 | ||
| 485 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 487 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
| 486 | 488 | ||
| 487 | for new_token in self.new_tokens: | 489 | for (token, ids) in zip(self.placeholder_token, self.new_ids): |
| 488 | text_encoder.text_model.embeddings.save_embed( | 490 | text_encoder.text_model.embeddings.save_embed( |
| 489 | new_token.ids, | 491 | ids, |
| 490 | checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin") | 492 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
| 491 | ) | 493 | ) |
| 492 | 494 | ||
| 493 | del text_encoder | 495 | del text_encoder |
| @@ -572,7 +574,7 @@ def main(): | |||
| 572 | raise ValueError("--embeddings_dir must point to an existing directory") | 574 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 573 | 575 | ||
| 574 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 576 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
| 575 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}") | 577 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 576 | 578 | ||
| 577 | # Convert the initializer_token, placeholder_token to ids | 579 | # Convert the initializer_token, placeholder_token to ids |
| 578 | initializer_token_ids = [ | 580 | initializer_token_ids = [ |
| @@ -588,7 +590,7 @@ def main(): | |||
| 588 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids) | 590 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids) |
| 589 | ] | 591 | ] |
| 590 | 592 | ||
| 591 | print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}") | 593 | print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") |
| 592 | 594 | ||
| 593 | vae.requires_grad_(False) | 595 | vae.requires_grad_(False) |
| 594 | unet.requires_grad_(False) | 596 | unet.requires_grad_(False) |
| @@ -882,7 +884,8 @@ def main(): | |||
| 882 | tokenizer=tokenizer, | 884 | tokenizer=tokenizer, |
| 883 | text_encoder=text_encoder, | 885 | text_encoder=text_encoder, |
| 884 | scheduler=checkpoint_scheduler, | 886 | scheduler=checkpoint_scheduler, |
| 885 | new_tokens=new_tokens, | 887 | placeholder_token=args.placeholder_token, |
| 888 | new_ids=new_ids, | ||
| 886 | output_dir=basepath, | 889 | output_dir=basepath, |
| 887 | sample_image_size=args.sample_image_size, | 890 | sample_image_size=args.sample_image_size, |
| 888 | sample_batch_size=args.sample_batch_size, | 891 | sample_batch_size=args.sample_batch_size, |
