Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  34
1 file changed, 34 insertions, 0 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 72c56cd..5c26f12 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -354,6 +354,8 @@ class Checkpointer:
         text_encoder,
         output_dir: Path,
         instance_identifier,
+        placeholder_token,
+        placeholder_token_id,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -368,12 +370,36 @@ class Checkpointer:
         self.text_encoder = text_encoder
         self.output_dir = output_dir
         self.instance_identifier = instance_identifier
+        self.placeholder_token = placeholder_token
+        self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
+    def save_embedding(self, step, postfix):
+        if self.placeholder_token_id is None:
+            return
+
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+
+        # Save only the learned embedding row of the placeholder token
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+
+        filename = "%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+
+        del unwrapped
+        del learned_embeds
+
+    @torch.no_grad()
     def save_model(self):
         print("Saving model...")
 
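For context, a minimal sketch of consuming one of these embedding files, following the usual textual-inversion loading pattern. The Hugging Face transformers classes, the model id, and the checkpoint path below are illustrative assumptions; none of them appear in this diff.

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    # Illustrative names only; substitute the model and path actually used.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # save_embedding() writes a one-entry dict: {placeholder_token: embedding tensor}.
    learned_embeds_dict = torch.load("checkpoints/my-token_500_training.bin", map_location="cpu")
    token, embedding = next(iter(learned_embeds_dict.items()))

    # Register the token and grow the embedding matrix by one row.
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Copy the trained vector into the new row, mirroring the slice saved above.
    token_id = tokenizer.convert_tokens_to_ids(token)
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[token_id] = embedding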
@@ -567,6 +593,8 @@ def main():
            text_encoder.text_model.final_layer_norm.parameters(),
            text_encoder.text_model.embeddings.position_embedding.parameters(),
        ))
+    else:
+        placeholder_token_id = None
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -785,6 +813,8 @@ def main():
         text_encoder=text_encoder,
         output_dir=basepath,
         instance_identifier=instance_identifier,
+        placeholder_token=args.placeholder_token,
+        placeholder_token_id=placeholder_token_id,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -902,6 +932,7 @@ def main():
                 global_step += 1
 
                 if global_step % args.sample_frequency == 0:
+                    checkpointer.save_embedding(global_step, "training")
                     sample_checkpoint = True
 
                 logs = {
@@ -968,6 +999,7 @@ def main():
             if min_val_loss > val_loss:
                 accelerator.print(
                     f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                checkpointer.save_embedding(global_step, "milestone")
                 min_val_loss = val_loss
 
             if sample_checkpoint and accelerator.is_main_process:
@@ -978,6 +1010,7 @@ def main():
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
+            checkpointer.save_embedding(global_step, "end")
             checkpointer.save_model()
 
         accelerator.end_training()
@@ -985,6 +1018,7 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
+            checkpointer.save_embedding(global_step, "end")
             checkpointer.save_model()
         accelerator.end_training()
         quit()
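As a usage note, the diff calls save_embedding() with three postfixes: "training" on the sampling cadence, "milestone" on a new validation-loss minimum, and "end" on normal completion or interrupt. A small sketch of the resulting filenames, assuming the python-slugify package provides the script's slugify import and using a hypothetical placeholder token:

    from slugify import slugify  # assumption: python-slugify

    for step, postfix in [(500, "training"), (750, "milestone"), (1000, "end")]:
        print("%s_%d_%s.bin" % (slugify("<my-token>"), step, postfix))
    # -> my-token_500_training.bin, my-token_750_milestone.bin, my-token_1000_end.bin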