diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 45 |
1 files changed, 23 insertions, 22 deletions
diff --git a/dreambooth.py b/dreambooth.py index 24e6091..a26bea7 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -68,7 +68,7 @@ def parse_args(): | |||
| 68 | parser.add_argument( | 68 | parser.add_argument( |
| 69 | "--num_class_images", | 69 | "--num_class_images", |
| 70 | type=int, | 70 | type=int, |
| 71 | default=2, | 71 | default=4, |
| 72 | help="How many class images to generate per training image." | 72 | help="How many class images to generate per training image." |
| 73 | ) | 73 | ) |
| 74 | parser.add_argument( | 74 | parser.add_argument( |
| @@ -106,7 +106,8 @@ def parse_args(): | |||
| 106 | parser.add_argument( | 106 | parser.add_argument( |
| 107 | "--num_train_epochs", | 107 | "--num_train_epochs", |
| 108 | type=int, | 108 | type=int, |
| 109 | default=100) | 109 | default=100 |
| 110 | ) | ||
| 110 | parser.add_argument( | 111 | parser.add_argument( |
| 111 | "--max_train_steps", | 112 | "--max_train_steps", |
| 112 | type=int, | 113 | type=int, |
| @@ -293,7 +294,7 @@ class Checkpointer: | |||
| 293 | unet, | 294 | unet, |
| 294 | tokenizer, | 295 | tokenizer, |
| 295 | text_encoder, | 296 | text_encoder, |
| 296 | output_dir, | 297 | output_dir: Path, |
| 297 | instance_identifier, | 298 | instance_identifier, |
| 298 | sample_image_size, | 299 | sample_image_size, |
| 299 | sample_batches, | 300 | sample_batches, |
| @@ -321,14 +322,14 @@ class Checkpointer: | |||
| 321 | pipeline = VlpnStableDiffusion( | 322 | pipeline = VlpnStableDiffusion( |
| 322 | text_encoder=self.text_encoder, | 323 | text_encoder=self.text_encoder, |
| 323 | vae=self.vae, | 324 | vae=self.vae, |
| 324 | unet=self.accelerator.unwrap_model(self.unet), | 325 | unet=unwrapped, |
| 325 | tokenizer=self.tokenizer, | 326 | tokenizer=self.tokenizer, |
| 326 | scheduler=PNDMScheduler( | 327 | scheduler=PNDMScheduler( |
| 327 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 328 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True |
| 328 | ), | 329 | ), |
| 329 | ) | 330 | ) |
| 330 | pipeline.enable_attention_slicing() | 331 | pipeline.enable_attention_slicing() |
| 331 | pipeline.save_pretrained(f"{self.output_dir}/model") | 332 | pipeline.save_pretrained(self.output_dir.joinpath("model")) |
| 332 | 333 | ||
| 333 | del unwrapped | 334 | del unwrapped |
| 334 | del pipeline | 335 | del pipeline |
| @@ -524,7 +525,7 @@ def main(): | |||
| 524 | tokenizer=tokenizer, | 525 | tokenizer=tokenizer, |
| 525 | instance_identifier=args.instance_identifier, | 526 | instance_identifier=args.instance_identifier, |
| 526 | class_identifier=args.class_identifier, | 527 | class_identifier=args.class_identifier, |
| 527 | class_subdir="db_cls", | 528 | class_subdir="cls", |
| 528 | num_class_images=args.num_class_images, | 529 | num_class_images=args.num_class_images, |
| 529 | size=args.resolution, | 530 | size=args.resolution, |
| 530 | repeats=args.repeats, | 531 | repeats=args.repeats, |
| @@ -580,21 +581,6 @@ def main(): | |||
| 580 | train_dataloader = datamodule.train_dataloader() | 581 | train_dataloader = datamodule.train_dataloader() |
| 581 | val_dataloader = datamodule.val_dataloader() | 582 | val_dataloader = datamodule.val_dataloader() |
| 582 | 583 | ||
| 583 | checkpointer = Checkpointer( | ||
| 584 | datamodule=datamodule, | ||
| 585 | accelerator=accelerator, | ||
| 586 | vae=vae, | ||
| 587 | unet=unet, | ||
| 588 | tokenizer=tokenizer, | ||
| 589 | text_encoder=text_encoder, | ||
| 590 | output_dir=basepath, | ||
| 591 | instance_identifier=args.instance_identifier, | ||
| 592 | sample_image_size=args.sample_image_size, | ||
| 593 | sample_batch_size=args.sample_batch_size, | ||
| 594 | sample_batches=args.sample_batches, | ||
| 595 | seed=args.seed | ||
| 596 | ) | ||
| 597 | |||
| 598 | # Scheduler and math around the number of training steps. | 584 | # Scheduler and math around the number of training steps. |
| 599 | overrode_max_train_steps = False | 585 | overrode_max_train_steps = False |
| 600 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 586 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
| @@ -613,7 +599,7 @@ def main(): | |||
| 613 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 599 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 614 | ) | 600 | ) |
| 615 | 601 | ||
| 616 | # Move vae and unet to device | 602 | # Move text_encoder and vae to device |
| 617 | text_encoder.to(accelerator.device) | 603 | text_encoder.to(accelerator.device) |
| 618 | vae.to(accelerator.device) | 604 | vae.to(accelerator.device) |
| 619 | 605 | ||
| @@ -649,6 +635,21 @@ def main(): | |||
| 649 | global_step = 0 | 635 | global_step = 0 |
| 650 | min_val_loss = np.inf | 636 | min_val_loss = np.inf |
| 651 | 637 | ||
| 638 | checkpointer = Checkpointer( | ||
| 639 | datamodule=datamodule, | ||
| 640 | accelerator=accelerator, | ||
| 641 | vae=vae, | ||
| 642 | unet=unet, | ||
| 643 | tokenizer=tokenizer, | ||
| 644 | text_encoder=text_encoder, | ||
| 645 | output_dir=basepath, | ||
| 646 | instance_identifier=args.instance_identifier, | ||
| 647 | sample_image_size=args.sample_image_size, | ||
| 648 | sample_batch_size=args.sample_batch_size, | ||
| 649 | sample_batches=args.sample_batches, | ||
| 650 | seed=args.seed | ||
| 651 | ) | ||
| 652 | |||
| 652 | if accelerator.is_main_process: | 653 | if accelerator.is_main_process: |
| 653 | checkpointer.save_samples( | 654 | checkpointer.save_samples( |
| 654 | 0, | 655 | 0, |
