 train_dreambooth.py |  4 ++--
 train_ti.py         | 19 +++++++++++--------
 2 files changed, 13 insertions(+), 10 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c658ad6..1fd86b1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -612,7 +612,7 @@ def main():
         raise ValueError("--embeddings_dir must point to an existing directory")
 
     added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
@@ -629,7 +629,7 @@ def main():
            for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
        ]
 
        print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
-        print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
+        print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
    else:
        placeholder_token_id = []
 
diff --git a/train_ti.py b/train_ti.py
index 5df6850..164cf67 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -450,7 +450,8 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
-        new_tokens,
+        placeholder_token,
+        new_ids,
         output_dir: Path,
         sample_image_size,
         sample_batches,
@@ -473,7 +474,8 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
-        self.new_tokens = new_tokens
+        self.placeholder_token = placeholder_token
+        self.new_ids = new_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -484,10 +486,10 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        for new_token in self.new_tokens:
+        for (token, ids) in zip(self.placeholder_token, self.new_ids):
             text_encoder.text_model.embeddings.save_embed(
-                new_token.ids,
-                checkpoints_path.joinpath(f"{slugify(new_token.token)}_{step}_{postfix}.bin")
+                ids,
+                checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
             )
 
         del text_encoder
@@ -572,7 +574,7 @@ def main():
         raise ValueError("--embeddings_dir must point to an existing directory")
 
     added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = [
@@ -588,7 +590,7 @@ def main():
         for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
     ]
 
-    print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
+    print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
@@ -882,7 +884,8 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
-        new_tokens=new_tokens,
+        placeholder_token=args.placeholder_token,
+        new_ids=new_ids,
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
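Note on the print() changes above: in Python 3, zip() returns a lazy iterator whose repr is not the zipped pairs, so the old log lines printed something like "<zip object at 0x...>" instead of the added tokens. Wrapping the call in list() materializes the pairs before the f-string formats them. A minimal sketch of the difference, using hypothetical token names and ids rather than values from the training scripts:

    tokens = ["<new-token-a>", "<new-token-b>"]  # hypothetical placeholder tokens
    ids = [[49408], [49409]]                     # hypothetical token ids

    # Old behavior: the f-string embeds the zip object's repr, not its contents.
    print(f"Added {len(tokens)} tokens: {zip(tokens, ids)}")        # -> <zip object at 0x...>

    # New behavior: list() forces evaluation, so the (token, ids) pairs are logged.
    print(f"Added {len(tokens)} tokens: {list(zip(tokens, ids))}")  # -> [('<new-token-a>', [49408]), ...]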