diff options
-rw-r--r-- | dreambooth.py | 4 | ||||
-rw-r--r-- | textual_inversion.py | 20 |
2 files changed, 10 insertions, 14 deletions
diff --git a/dreambooth.py b/dreambooth.py index 9a6f70a..31416e9 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -625,8 +625,6 @@ def main(): | |||
625 | vae.requires_grad_(False) | 625 | vae.requires_grad_(False) |
626 | 626 | ||
627 | if len(args.placeholder_token) != 0: | 627 | if len(args.placeholder_token) != 0: |
628 | print(f"Adding text embeddings: {args.placeholder_token}") | ||
629 | |||
630 | # Convert the initializer_token, placeholder_token to ids | 628 | # Convert the initializer_token, placeholder_token to ids |
631 | initializer_token_ids = torch.stack([ | 629 | initializer_token_ids = torch.stack([ |
632 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) | 630 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) |
@@ -1114,7 +1112,7 @@ def main(): | |||
1114 | max_acc_val = avg_acc_val.avg.item() | 1112 | max_acc_val = avg_acc_val.avg.item() |
1115 | 1113 | ||
1116 | if accelerator.is_main_process: | 1114 | if accelerator.is_main_process: |
1117 | if epoch % args.sample_frequency == 0: | 1115 | if (epoch + 1) % args.sample_frequency == 0: |
1118 | checkpointer.save_samples(global_step, args.sample_steps) | 1116 | checkpointer.save_samples(global_step, args.sample_steps) |
1119 | 1117 | ||
1120 | # Create the pipeline using using the trained modules and save it. | 1118 | # Create the pipeline using using the trained modules and save it. |
diff --git a/textual_inversion.py b/textual_inversion.py index 11babd8..19b8993 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -302,15 +302,11 @@ def parse_args(): | |||
302 | "--config", | 302 | "--config", |
303 | type=str, | 303 | type=str, |
304 | default=None, | 304 | default=None, |
305 | help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this." | 305 | help="Path to a JSON configuration file containing arguments for invoking this script." |
306 | ) | 306 | ) |
307 | 307 | ||
308 | args = parser.parse_args() | 308 | args = parser.parse_args() |
309 | if args.resume_from is not None: | 309 | if args.config is not None: |
310 | with open(f"{args.resume_from}/resume.json", 'rt') as f: | ||
311 | args = parser.parse_args( | ||
312 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
313 | elif args.config is not None: | ||
314 | with open(args.config, 'rt') as f: | 310 | with open(args.config, 'rt') as f: |
315 | args = parser.parse_args( | 311 | args = parser.parse_args( |
316 | namespace=argparse.Namespace(**json.load(f)["args"])) | 312 | namespace=argparse.Namespace(**json.load(f)["args"])) |
@@ -550,8 +546,6 @@ def main(): | |||
550 | unet.enable_gradient_checkpointing() | 546 | unet.enable_gradient_checkpointing() |
551 | text_encoder.gradient_checkpointing_enable() | 547 | text_encoder.gradient_checkpointing_enable() |
552 | 548 | ||
553 | print(f"Adding text embeddings: {args.placeholder_token}") | ||
554 | |||
555 | # Convert the initializer_token, placeholder_token to ids | 549 | # Convert the initializer_token, placeholder_token to ids |
556 | initializer_token_ids = torch.stack([ | 550 | initializer_token_ids = torch.stack([ |
557 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) | 551 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) |
@@ -563,13 +557,17 @@ def main(): | |||
563 | 557 | ||
564 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 558 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
565 | 559 | ||
560 | print(f"Token ID mappings:") | ||
561 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | ||
562 | print(f"- {token_id} {token}") | ||
563 | |||
566 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | 564 | # Resize the token embeddings as we are adding new special tokens to the tokenizer |
567 | text_encoder.resize_token_embeddings(len(tokenizer)) | 565 | text_encoder.resize_token_embeddings(len(tokenizer)) |
568 | 566 | ||
569 | # Initialise the newly added placeholder token with the embeddings of the initializer token | 567 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
570 | token_embeds = text_encoder.get_input_embeddings().weight.data | 568 | token_embeds = text_encoder.get_input_embeddings().weight.data |
571 | 569 | ||
572 | if args.resume_from: | 570 | if args.resume_from is not None: |
573 | resumepath = Path(args.resume_from).joinpath("checkpoints") | 571 | resumepath = Path(args.resume_from).joinpath("checkpoints") |
574 | 572 | ||
575 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | 573 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): |
@@ -983,13 +981,13 @@ def main(): | |||
983 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | 981 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") |
984 | min_val_loss = val_loss | 982 | min_val_loss = val_loss |
985 | 983 | ||
986 | if epoch % args.checkpoint_frequency == 0: | 984 | if (epoch + 1) % args.checkpoint_frequency == 0: |
987 | checkpointer.checkpoint(global_step + global_step_offset, "training") | 985 | checkpointer.checkpoint(global_step + global_step_offset, "training") |
988 | save_args(basepath, args, { | 986 | save_args(basepath, args, { |
989 | "global_step": global_step + global_step_offset | 987 | "global_step": global_step + global_step_offset |
990 | }) | 988 | }) |
991 | 989 | ||
992 | if epoch % args.sample_frequency == 0: | 990 | if (epoch + 1) % args.sample_frequency == 0: |
993 | checkpointer.save_samples( | 991 | checkpointer.save_samples( |
994 | global_step + global_step_offset, | 992 | global_step + global_step_offset, |
995 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 993 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |