summary refs log tree commit diff stats
path: root/dreambooth_plus.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--  dreambooth_plus.py  59
1 file changed, 50 insertions, 9 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 7996bc2..b5ec2fc 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -58,6 +58,12 @@ def parse_args():
58 parser.add_argument( 58 parser.add_argument(
59 "--placeholder_token", 59 "--placeholder_token",
60 type=str, 60 type=str,
61 default="<*>",
62 help="A token to use as a placeholder for the concept.",
63 )
64 parser.add_argument(
65 "--class_identifier",
66 type=str,
61 default=None, 67 default=None,
62 help="A token to use as a placeholder for the concept.", 68 help="A token to use as a placeholder for the concept.",
63 ) 69 )
@@ -71,7 +77,7 @@ def parse_args():
71 "--num_class_images", 77 "--num_class_images",
72 type=int, 78 type=int,
73 default=400, 79 default=400,
74 help="How many class images to generate per training image." 80 help="How many class images to generate."
75 ) 81 )
76 parser.add_argument( 82 parser.add_argument(
77 "--repeats", 83 "--repeats",
@@ -112,7 +118,7 @@ def parse_args():
112 parser.add_argument( 118 parser.add_argument(
113 "--max_train_steps", 119 "--max_train_steps",
114 type=int, 120 type=int,
115 default=1600, 121 default=2300,
116 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 122 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
117 ) 123 )
118 parser.add_argument( 124 parser.add_argument(
@@ -135,7 +141,7 @@ def parse_args():
135 parser.add_argument( 141 parser.add_argument(
136 "--learning_rate_text", 142 "--learning_rate_text",
137 type=float, 143 type=float,
138 default=5e-4, 144 default=5e-6,
139 help="Initial learning rate (after the potential warmup period) to use.", 145 help="Initial learning rate (after the potential warmup period) to use.",
140 ) 146 )
141 parser.add_argument( 147 parser.add_argument(
@@ -222,6 +228,12 @@ def parse_args():
222 ), 228 ),
223 ) 229 )
224 parser.add_argument( 230 parser.add_argument(
231 "--checkpoint_frequency",
232 type=int,
233 default=500,
234 help="How often to save a checkpoint and sample image",
235 )
236 parser.add_argument(
225 "--sample_frequency", 237 "--sample_frequency",
226 type=int, 238 type=int,
227 default=100, 239 default=100,
@@ -352,7 +364,26 @@ class Checkpointer:
352 self.sample_batch_size = sample_batch_size 364 self.sample_batch_size = sample_batch_size
353 365
354 @torch.no_grad() 366 @torch.no_grad()
355 def checkpoint(self): 367 def checkpoint(self, step, postfix):
368 print("Saving checkpoint for step %d..." % step)
369
370 checkpoints_path = self.output_dir.joinpath("checkpoints")
371 checkpoints_path.mkdir(parents=True, exist_ok=True)
372
373 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
374
375 # Save a checkpoint
376 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
377 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
378
379 filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
380 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
381
382 del unwrapped
383 del learned_embeds
384
385 @torch.no_grad()
386 def save_model(self):
356 print("Saving model...") 387 print("Saving model...")
357 388
358 unwrapped_unet = self.accelerator.unwrap_model( 389 unwrapped_unet = self.accelerator.unwrap_model(
@@ -612,7 +643,7 @@ def main():
612 batch_size=args.train_batch_size, 643 batch_size=args.train_batch_size,
613 tokenizer=tokenizer, 644 tokenizer=tokenizer,
614 instance_identifier=args.placeholder_token, 645 instance_identifier=args.placeholder_token,
615 class_identifier=args.initializer_token, 646 class_identifier=args.class_identifier,
616 class_subdir="cls", 647 class_subdir="cls",
617 num_class_images=args.num_class_images, 648 num_class_images=args.num_class_images,
618 size=args.resolution, 649 size=args.resolution,
@@ -648,7 +679,7 @@ def main():
648 with torch.inference_mode(): 679 with torch.inference_mode():
649 for batch in batched_data: 680 for batch in batched_data:
650 image_name = [p.class_image_path for p in batch] 681 image_name = [p.class_image_path for p in batch]
651 prompt = [p.prompt.format(args.initializer_token) for p in batch] 682 prompt = [p.prompt.format(args.class_identifier) for p in batch]
652 nprompt = [p.nprompt for p in batch] 683 nprompt = [p.nprompt for p in batch]
653 684
654 images = pipeline( 685 images = pipeline(
@@ -842,6 +873,12 @@ def main():
842 if global_step % args.sample_frequency == 0: 873 if global_step % args.sample_frequency == 0:
843 sample_checkpoint = True 874 sample_checkpoint = True
844 875
876 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
877 local_progress_bar.clear()
878 global_progress_bar.clear()
879
880 checkpointer.checkpoint(global_step + global_step_offset, "training")
881
845 logs = { 882 logs = {
846 "train/loss": loss, 883 "train/loss": loss,
847 "lr/unet": lr_scheduler.get_last_lr()[0], 884 "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -903,6 +940,9 @@ def main():
903 global_progress_bar.clear() 940 global_progress_bar.clear()
904 941
905 if min_val_loss > val_loss: 942 if min_val_loss > val_loss:
943 accelerator.print(
944 f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
945 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
906 min_val_loss = val_loss 946 min_val_loss = val_loss
907 947
908 if sample_checkpoint and accelerator.is_main_process: 948 if sample_checkpoint and accelerator.is_main_process:
@@ -913,14 +953,15 @@ def main():
913 # Create the pipeline using using the trained modules and save it. 953 # Create the pipeline using using the trained modules and save it.
914 if accelerator.is_main_process: 954 if accelerator.is_main_process:
915 print("Finished! Saving final checkpoint and resume state.") 955 print("Finished! Saving final checkpoint and resume state.")
916 checkpointer.checkpoint() 956 checkpointer.checkpoint(global_step + global_step_offset, "end")
917 957 checkpointer.save_model()
918 accelerator.end_training() 958 accelerator.end_training()
919 959
920 except KeyboardInterrupt: 960 except KeyboardInterrupt:
921 if accelerator.is_main_process: 961 if accelerator.is_main_process:
922 print("Interrupted, saving checkpoint and resume state...") 962 print("Interrupted, saving checkpoint and resume state...")
923 checkpointer.checkpoint() 963 checkpointer.checkpoint(global_step + global_step_offset, "end")
964 checkpointer.save_model()
924 accelerator.end_training() 965 accelerator.end_training()
925 quit() 966 quit()
926 967