diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/dreambooth.py b/dreambooth.py index 72c56cd..5c26f12 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -354,6 +354,8 @@ class Checkpointer: | |||
| 354 | text_encoder, | 354 | text_encoder, |
| 355 | output_dir: Path, | 355 | output_dir: Path, |
| 356 | instance_identifier, | 356 | instance_identifier, |
| 357 | placeholder_token, | ||
| 358 | placeholder_token_id, | ||
| 357 | sample_image_size, | 359 | sample_image_size, |
| 358 | sample_batches, | 360 | sample_batches, |
| 359 | sample_batch_size, | 361 | sample_batch_size, |
| @@ -368,12 +370,36 @@ class Checkpointer: | |||
| 368 | self.text_encoder = text_encoder | 370 | self.text_encoder = text_encoder |
| 369 | self.output_dir = output_dir | 371 | self.output_dir = output_dir |
| 370 | self.instance_identifier = instance_identifier | 372 | self.instance_identifier = instance_identifier |
| 373 | self.placeholder_token = placeholder_token | ||
| 374 | self.placeholder_token_id = placeholder_token_id | ||
| 371 | self.sample_image_size = sample_image_size | 375 | self.sample_image_size = sample_image_size |
| 372 | self.seed = seed or torch.random.seed() | 376 | self.seed = seed or torch.random.seed() |
| 373 | self.sample_batches = sample_batches | 377 | self.sample_batches = sample_batches |
| 374 | self.sample_batch_size = sample_batch_size | 378 | self.sample_batch_size = sample_batch_size |
| 375 | 379 | ||
| 376 | @torch.no_grad() | 380 | @torch.no_grad() |
| 381 | def save_embedding(self, step, postfix): | ||
| 382 | if self.placeholder_token_id is None: | ||
| 383 | return | ||
| 384 | |||
| 385 | print("Saving checkpoint for step %d..." % step) | ||
| 386 | |||
| 387 | checkpoints_path = self.output_dir.joinpath("checkpoints") | ||
| 388 | checkpoints_path.mkdir(parents=True, exist_ok=True) | ||
| 389 | |||
| 390 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | ||
| 391 | |||
| 392 | # Save a checkpoint | ||
| 393 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] | ||
| 394 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} | ||
| 395 | |||
| 396 | filename = "%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) | ||
| 397 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) | ||
| 398 | |||
| 399 | del unwrapped | ||
| 400 | del learned_embeds | ||
| 401 | |||
| 402 | @torch.no_grad() | ||
| 377 | def save_model(self): | 403 | def save_model(self): |
| 378 | print("Saving model...") | 404 | print("Saving model...") |
| 379 | 405 | ||
| @@ -567,6 +593,8 @@ def main(): | |||
| 567 | text_encoder.text_model.final_layer_norm.parameters(), | 593 | text_encoder.text_model.final_layer_norm.parameters(), |
| 568 | text_encoder.text_model.embeddings.position_embedding.parameters(), | 594 | text_encoder.text_model.embeddings.position_embedding.parameters(), |
| 569 | )) | 595 | )) |
| 596 | else: | ||
| 597 | placeholder_token_id = None | ||
| 570 | 598 | ||
| 571 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 599 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
| 572 | 600 | ||
| @@ -785,6 +813,8 @@ def main(): | |||
| 785 | text_encoder=text_encoder, | 813 | text_encoder=text_encoder, |
| 786 | output_dir=basepath, | 814 | output_dir=basepath, |
| 787 | instance_identifier=instance_identifier, | 815 | instance_identifier=instance_identifier, |
| 816 | placeholder_token=args.placeholder_token, | ||
| 817 | placeholder_token_id=placeholder_token_id, | ||
| 788 | sample_image_size=args.sample_image_size, | 818 | sample_image_size=args.sample_image_size, |
| 789 | sample_batch_size=args.sample_batch_size, | 819 | sample_batch_size=args.sample_batch_size, |
| 790 | sample_batches=args.sample_batches, | 820 | sample_batches=args.sample_batches, |
| @@ -902,6 +932,7 @@ def main(): | |||
| 902 | global_step += 1 | 932 | global_step += 1 |
| 903 | 933 | ||
| 904 | if global_step % args.sample_frequency == 0: | 934 | if global_step % args.sample_frequency == 0: |
| 935 | checkpointer.save_embedding(global_step, "training") | ||
| 905 | sample_checkpoint = True | 936 | sample_checkpoint = True |
| 906 | 937 | ||
| 907 | logs = { | 938 | logs = { |
| @@ -968,6 +999,7 @@ def main(): | |||
| 968 | if min_val_loss > val_loss: | 999 | if min_val_loss > val_loss: |
| 969 | accelerator.print( | 1000 | accelerator.print( |
| 970 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 1001 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
| 1002 | checkpointer.save_embedding(global_step, "milestone") | ||
| 971 | min_val_loss = val_loss | 1003 | min_val_loss = val_loss |
| 972 | 1004 | ||
| 973 | if sample_checkpoint and accelerator.is_main_process: | 1005 | if sample_checkpoint and accelerator.is_main_process: |
| @@ -978,6 +1010,7 @@ def main(): | |||
| 978 | # Create the pipeline using the trained modules and save it. | 1010 | # Create the pipeline using the trained modules and save it. |
| 979 | if accelerator.is_main_process: | 1011 | if accelerator.is_main_process: |
| 980 | print("Finished! Saving final checkpoint and resume state.") | 1012 | print("Finished! Saving final checkpoint and resume state.") |
| 1013 | checkpointer.save_embedding(global_step, "end") | ||
| 981 | checkpointer.save_model() | 1014 | checkpointer.save_model() |
| 982 | 1015 | ||
| 983 | accelerator.end_training() | 1016 | accelerator.end_training() |
| @@ -985,6 +1018,7 @@ def main(): | |||
| 985 | except KeyboardInterrupt: | 1018 | except KeyboardInterrupt: |
| 986 | if accelerator.is_main_process: | 1019 | if accelerator.is_main_process: |
| 987 | print("Interrupted, saving checkpoint and resume state...") | 1020 | print("Interrupted, saving checkpoint and resume state...") |
| 1021 | checkpointer.save_embedding(global_step, "end") | ||
| 988 | checkpointer.save_model() | 1022 | checkpointer.save_model() |
| 989 | accelerator.end_training() | 1023 | accelerator.end_training() |
| 990 | quit() | 1024 | quit() |
