author    Volpeon <git@volpeon.ink>  2022-10-07 14:54:44 +0200
committer Volpeon <git@volpeon.ink>  2022-10-07 14:54:44 +0200
commit    2af0d47b44fe02269b1378f7691d258d35544bb3 (patch)
tree      cb3250de69e17ad3536e0f548805b7a087a041f2 /dreambooth.py
parent    Training: Create multiple class images per training image (diff)
Fix small details
Diffstat (limited to 'dreambooth.py')
 dreambooth.py | 45 +++++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 24e6091..a26bea7 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -68,7 +68,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=2,
+        default=4,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
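Note: the default for --num_class_images doubles from 2 to 4. Per the parent commit ("Training: Create multiple class images per training image"), this count is per training image, so the generated class-image set scales with the instance set. A minimal sketch of the arithmetic, with hypothetical numbers, assuming the data module simply multiplies the two counts as the help text implies:

    num_train_images = 25  # instance images in the dataset (hypothetical)
    num_class_images = 4   # new default: class images per training image
    print(num_train_images * num_class_images)  # 100 class images in total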
@@ -106,7 +106,8 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100)
+        default=100
+    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -293,7 +294,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
-        output_dir,
+        output_dir: Path,
         instance_identifier,
         sample_image_size,
         sample_batches,
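Note: output_dir is now annotated as pathlib.Path, which is what enables the joinpath call in the next hunk. A minimal sketch of the equivalence, with a hypothetical location:

    from pathlib import Path

    output_dir = Path("output/dreambooth")    # hypothetical location
    model_dir = output_dir.joinpath("model")  # replaces f"{output_dir}/model"
    assert model_dir == output_dir / "model"  # joinpath and / are equivalent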
@@ -321,14 +322,14 @@ class Checkpointer:
         pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
-            unet=self.accelerator.unwrap_model(self.unet),
+            unet=unwrapped,
             tokenizer=self.tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
             ),
         )
         pipeline.enable_attention_slicing()
-        pipeline.save_pretrained(f"{self.output_dir}/model")
+        pipeline.save_pretrained(self.output_dir.joinpath("model"))

         del unwrapped
         del pipeline
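Note: the pipeline is now built from the local `unwrapped` reference (presumably created once earlier in the method; the pre-existing `del unwrapped` in the context lines supports that reading) instead of calling accelerator.unwrap_model() inline. A minimal sketch of the unwrap-once pattern, using a stand-in module:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    unet = accelerator.prepare(torch.nn.Linear(4, 4))  # stand-in for the real unet

    # unwrap once, then reuse the plain module wherever it is needed
    unwrapped = accelerator.unwrap_model(unet)
    del unwrapped  # drop the reference when done, as the diff does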
@@ -524,7 +525,7 @@ def main():
         tokenizer=tokenizer,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
-        class_subdir="db_cls",
+        class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
@@ -580,21 +581,6 @@ def main():
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()

-    checkpointer = Checkpointer(
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        output_dir=basepath,
-        instance_identifier=args.instance_identifier,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
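Note: this hunk only moves code; the Checkpointer construction deleted here reappears unchanged after accelerator.prepare() in the final hunk below.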
@@ -613,7 +599,7 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    # Move vae and unet to device
+    # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)

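Note: the comment now matches the code: it is text_encoder and vae that are moved by hand, because only the unet (with the optimizer, dataloaders, and scheduler) goes through accelerator.prepare(). A minimal sketch of that split, with stand-in modules:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    unet = torch.nn.Linear(8, 8)          # stand-in for the trained model
    text_encoder = torch.nn.Linear(8, 8)  # stand-in for a frozen model
    vae = torch.nn.Linear(8, 8)           # stand-in for a frozen model

    unet = accelerator.prepare(unet)      # prepare() handles device placement
    text_encoder.to(accelerator.device)   # frozen models are moved manually
    vae.to(accelerator.device)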
@@ -649,6 +635,21 @@ def main():
     global_step = 0
     min_val_loss = np.inf

+    checkpointer = Checkpointer(
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        output_dir=basepath,
+        instance_identifier=args.instance_identifier,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
     if accelerator.is_main_process:
         checkpointer.save_samples(
             0,
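Note: together with the removal above, the Checkpointer is now constructed after accelerator.prepare(), so the unet reference it stores is the prepared (possibly wrapped) model rather than the raw one. A minimal sketch of the ordering, with a hypothetical stand-in class:

    import torch
    from accelerate import Accelerator

    class Checkpointer:  # hypothetical stand-in for the class in dreambooth.py
        def __init__(self, unet):
            self.unet = unet  # keeps a reference to whatever it is given

    accelerator = Accelerator()
    unet = accelerator.prepare(torch.nn.Linear(8, 8))

    # constructing after prepare() means the stored reference is the prepared model
    checkpointer = Checkpointer(unet=unet)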