diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/dreambooth.py b/dreambooth.py index 72c56cd..5c26f12 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -354,6 +354,8 @@ class Checkpointer: | |||
354 | text_encoder, | 354 | text_encoder, |
355 | output_dir: Path, | 355 | output_dir: Path, |
356 | instance_identifier, | 356 | instance_identifier, |
357 | placeholder_token, | ||
358 | placeholder_token_id, | ||
357 | sample_image_size, | 359 | sample_image_size, |
358 | sample_batches, | 360 | sample_batches, |
359 | sample_batch_size, | 361 | sample_batch_size, |
@@ -368,12 +370,36 @@ class Checkpointer: | |||
368 | self.text_encoder = text_encoder | 370 | self.text_encoder = text_encoder |
369 | self.output_dir = output_dir | 371 | self.output_dir = output_dir |
370 | self.instance_identifier = instance_identifier | 372 | self.instance_identifier = instance_identifier |
373 | self.placeholder_token = placeholder_token | ||
374 | self.placeholder_token_id = placeholder_token_id | ||
371 | self.sample_image_size = sample_image_size | 375 | self.sample_image_size = sample_image_size |
372 | self.seed = seed or torch.random.seed() | 376 | self.seed = seed or torch.random.seed() |
373 | self.sample_batches = sample_batches | 377 | self.sample_batches = sample_batches |
374 | self.sample_batch_size = sample_batch_size | 378 | self.sample_batch_size = sample_batch_size |
375 | 379 | ||
376 | @torch.no_grad() | 380 | @torch.no_grad() |
381 | def save_embedding(self, step, postfix): | ||
382 | if self.placeholder_token_id is None: | ||
383 | return | ||
384 | |||
385 | print("Saving checkpoint for step %d..." % step) | ||
386 | |||
387 | checkpoints_path = self.output_dir.joinpath("checkpoints") | ||
388 | checkpoints_path.mkdir(parents=True, exist_ok=True) | ||
389 | |||
390 | unwrapped = self.accelerator.unwrap_model(self.text_encoder) | ||
391 | |||
392 | # Save a checkpoint | ||
393 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] | ||
394 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} | ||
395 | |||
396 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) | ||
397 | torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) | ||
398 | |||
399 | del unwrapped | ||
400 | del learned_embeds | ||
401 | |||
402 | @torch.no_grad() | ||
377 | def save_model(self): | 403 | def save_model(self): |
378 | print("Saving model...") | 404 | print("Saving model...") |
379 | 405 | ||
@@ -567,6 +593,8 @@ def main(): | |||
567 | text_encoder.text_model.final_layer_norm.parameters(), | 593 | text_encoder.text_model.final_layer_norm.parameters(), |
568 | text_encoder.text_model.embeddings.position_embedding.parameters(), | 594 | text_encoder.text_model.embeddings.position_embedding.parameters(), |
569 | )) | 595 | )) |
596 | else: | ||
597 | placeholder_token_id = None | ||
570 | 598 | ||
571 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 599 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
572 | 600 | ||
@@ -785,6 +813,8 @@ def main(): | |||
785 | text_encoder=text_encoder, | 813 | text_encoder=text_encoder, |
786 | output_dir=basepath, | 814 | output_dir=basepath, |
787 | instance_identifier=instance_identifier, | 815 | instance_identifier=instance_identifier, |
816 | placeholder_token=args.placeholder_token, | ||
817 | placeholder_token_id=placeholder_token_id, | ||
788 | sample_image_size=args.sample_image_size, | 818 | sample_image_size=args.sample_image_size, |
789 | sample_batch_size=args.sample_batch_size, | 819 | sample_batch_size=args.sample_batch_size, |
790 | sample_batches=args.sample_batches, | 820 | sample_batches=args.sample_batches, |
@@ -902,6 +932,7 @@ def main(): | |||
902 | global_step += 1 | 932 | global_step += 1 |
903 | 933 | ||
904 | if global_step % args.sample_frequency == 0: | 934 | if global_step % args.sample_frequency == 0: |
935 | checkpointer.save_embedding(global_step, "training") | ||
905 | sample_checkpoint = True | 936 | sample_checkpoint = True |
906 | 937 | ||
907 | logs = { | 938 | logs = { |
@@ -968,6 +999,7 @@ def main(): | |||
968 | if min_val_loss > val_loss: | 999 | if min_val_loss > val_loss: |
969 | accelerator.print( | 1000 | accelerator.print( |
970 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 1001 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
1002 | checkpointer.save_embedding(global_step, "milestone") | ||
971 | min_val_loss = val_loss | 1003 | min_val_loss = val_loss |
972 | 1004 | ||
973 | if sample_checkpoint and accelerator.is_main_process: | 1005 | if sample_checkpoint and accelerator.is_main_process: |
@@ -978,6 +1010,7 @@ def main(): | |||
978 | # Create the pipeline using using the trained modules and save it. | 1010 | # Create the pipeline using using the trained modules and save it. |
979 | if accelerator.is_main_process: | 1011 | if accelerator.is_main_process: |
980 | print("Finished! Saving final checkpoint and resume state.") | 1012 | print("Finished! Saving final checkpoint and resume state.") |
1013 | checkpointer.save_embedding(global_step, "end") | ||
981 | checkpointer.save_model() | 1014 | checkpointer.save_model() |
982 | 1015 | ||
983 | accelerator.end_training() | 1016 | accelerator.end_training() |
@@ -985,6 +1018,7 @@ def main(): | |||
985 | except KeyboardInterrupt: | 1018 | except KeyboardInterrupt: |
986 | if accelerator.is_main_process: | 1019 | if accelerator.is_main_process: |
987 | print("Interrupted, saving checkpoint and resume state...") | 1020 | print("Interrupted, saving checkpoint and resume state...") |
1021 | checkpointer.save_embedding(global_step, "end") | ||
988 | checkpointer.save_model() | 1022 | checkpointer.save_model() |
989 | accelerator.end_training() | 1023 | accelerator.end_training() |
990 | quit() | 1024 | quit() |