Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--   dreambooth_plus.py   59
1 file changed, 50 insertions(+), 9 deletions(-)
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 7996bc2..b5ec2fc 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -58,6 +58,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
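Aside, not part of the commit: the new --class_identifier flag separates the plain-language class name (used to build prior-preservation prompts, see the args.class_identifier hunks further down) from --placeholder_token, the pseudo-word whose embedding is trained. A minimal sketch of the distinction; the prompt templates here are invented for illustration:

    # Illustrative sketch only; prompt templates are made up.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--placeholder_token", type=str, default="<*>")
    parser.add_argument("--class_identifier", type=str, default=None)
    args = parser.parse_args(["--class_identifier", "dog"])

    # Instance prompts use the trainable pseudo-word, class prompts the real class name.
    instance_prompt = "a photo of {}".format(args.placeholder_token)  # -> "a photo of <*>"
    class_prompt = "a photo of {}".format(args.class_identifier)      # -> "a photo of dog"
    print(instance_prompt)
    print(class_prompt)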
@@ -71,7 +77,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -112,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1600,
+        default=2300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -135,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-4,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -222,6 +228,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--checkpoint_frequency",
+        type=int,
+        default=500,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=100,
@@ -352,7 +364,26 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self):
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+
+        # Save a checkpoint
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+
+        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+
+        del unwrapped
+        del learned_embeds
+
+    @torch.no_grad()
+    def save_model(self):
         print("Saving model...")
 
         unwrapped_unet = self.accelerator.unwrap_model(
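For illustration only (outside the commit): the reworked checkpoint(step, postfix) writes just the learned embedding as a one-entry {placeholder_token: tensor} dict via torch.save, while the full pipeline export moves to the new save_model(). A rough sketch of loading such a .bin file back into a text encoder, assuming a diffusers/transformers-style CLIP tokenizer and text model (file name and model id below are placeholders, not taken from this repository):

    import torch
    from transformers import CLIPTokenizer, CLIPTextModel

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # The checkpoint is a one-entry dict: {placeholder_token: embedding_tensor}.
    learned_embeds = torch.load("checkpoints/token_500_training.bin", map_location="cpu")
    token, embedding = next(iter(learned_embeds.items()))

    tokenizer.add_tokens(token)                            # register the placeholder token
    text_encoder.resize_token_embeddings(len(tokenizer))   # grow the embedding matrix to match
    token_id = tokenizer.convert_tokens_to_ids(token)
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[token_id] = embedding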
@@ -612,7 +643,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -648,7 +679,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]
 
                 images = pipeline(
@@ -842,6 +873,12 @@ def main():
                 if global_step % args.sample_frequency == 0:
                     sample_checkpoint = True
 
+                if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+
                 logs = {
                     "train/loss": loss,
                     "lr/unet": lr_scheduler.get_last_lr()[0],
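As an aside (sketch, not from the diff): the new checkpoint_frequency counter runs independently of sample_frequency, so with the defaults above (100 and 500) every fifth sampling step also writes a "training" embedding checkpoint. A toy sketch of that cadence:

    # Toy cadence check using the default frequencies from parse_args() above.
    sample_frequency = 100      # marks the step for sample generation (sample_checkpoint = True)
    checkpoint_frequency = 500  # writes a "training" embedding checkpoint

    for global_step in range(1, 1501):
        if global_step % checkpoint_frequency == 0:
            print(f"step {global_step}: embedding checkpoint (sampling also triggers)")
        elif global_step % sample_frequency == 0:
            print(f"step {global_step}: sampling only")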
@@ -903,6 +940,9 @@ def main():
                 global_progress_bar.clear()
 
                 if min_val_loss > val_loss:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                     min_val_loss = val_loss
 
                 if sample_checkpoint and accelerator.is_main_process:
@@ -913,14 +953,15 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint()
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()
             accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint()
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()
             accelerator.end_training()
             quit()
 
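Not part of the diff above: both exit paths now perform the same two saves, an "end" embedding checkpoint followed by the full model. A bare-bones sketch of that save-on-exit pattern (the checkpointer argument and loop body are stand-ins, not the script's real structure):

    # Minimal sketch of the try/except KeyboardInterrupt save-on-exit pattern.
    def run_training(checkpointer, max_steps, global_step_offset=0):
        global_step = 0
        try:
            for global_step in range(1, max_steps + 1):
                pass  # ... one optimization step per iteration ...
            checkpointer.checkpoint(global_step + global_step_offset, "end")
            checkpointer.save_model()
        except KeyboardInterrupt:
            print("Interrupted, saving checkpoint and resume state...")
            checkpointer.checkpoint(global_step + global_step_offset, "end")
            checkpointer.save_model()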
