diff options
author | Volpeon <git@volpeon.ink> | 2022-10-07 14:54:44 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-07 14:54:44 +0200 |
commit | 2af0d47b44fe02269b1378f7691d258d35544bb3 (patch) | |
tree | cb3250de69e17ad3536e0f548805b7a087a041f2 /dreambooth.py | |
parent | Training: Create multiple class images per training image (diff) | |
download | textual-inversion-diff-2af0d47b44fe02269b1378f7691d258d35544bb3.tar.gz textual-inversion-diff-2af0d47b44fe02269b1378f7691d258d35544bb3.tar.bz2 textual-inversion-diff-2af0d47b44fe02269b1378f7691d258d35544bb3.zip |
Fix small details
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 45 |
1 file changed, 23 insertions, 22 deletions
diff --git a/dreambooth.py b/dreambooth.py index 24e6091..a26bea7 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -68,7 +68,7 @@ def parse_args(): | |||
68 | parser.add_argument( | 68 | parser.add_argument( |
69 | "--num_class_images", | 69 | "--num_class_images", |
70 | type=int, | 70 | type=int, |
71 | default=2, | 71 | default=4, |
72 | help="How many class images to generate per training image." | 72 | help="How many class images to generate per training image." |
73 | ) | 73 | ) |
74 | parser.add_argument( | 74 | parser.add_argument( |
@@ -106,7 +106,8 @@ def parse_args(): | |||
106 | parser.add_argument( | 106 | parser.add_argument( |
107 | "--num_train_epochs", | 107 | "--num_train_epochs", |
108 | type=int, | 108 | type=int, |
109 | default=100) | 109 | default=100 |
110 | ) | ||
110 | parser.add_argument( | 111 | parser.add_argument( |
111 | "--max_train_steps", | 112 | "--max_train_steps", |
112 | type=int, | 113 | type=int, |
@@ -293,7 +294,7 @@ class Checkpointer: | |||
293 | unet, | 294 | unet, |
294 | tokenizer, | 295 | tokenizer, |
295 | text_encoder, | 296 | text_encoder, |
296 | output_dir, | 297 | output_dir: Path, |
297 | instance_identifier, | 298 | instance_identifier, |
298 | sample_image_size, | 299 | sample_image_size, |
299 | sample_batches, | 300 | sample_batches, |
@@ -321,14 +322,14 @@ class Checkpointer: | |||
321 | pipeline = VlpnStableDiffusion( | 322 | pipeline = VlpnStableDiffusion( |
322 | text_encoder=self.text_encoder, | 323 | text_encoder=self.text_encoder, |
323 | vae=self.vae, | 324 | vae=self.vae, |
324 | unet=self.accelerator.unwrap_model(self.unet), | 325 | unet=unwrapped, |
325 | tokenizer=self.tokenizer, | 326 | tokenizer=self.tokenizer, |
326 | scheduler=PNDMScheduler( | 327 | scheduler=PNDMScheduler( |
327 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 328 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True |
328 | ), | 329 | ), |
329 | ) | 330 | ) |
330 | pipeline.enable_attention_slicing() | 331 | pipeline.enable_attention_slicing() |
331 | pipeline.save_pretrained(f"{self.output_dir}/model") | 332 | pipeline.save_pretrained(self.output_dir.joinpath("model")) |
332 | 333 | ||
333 | del unwrapped | 334 | del unwrapped |
334 | del pipeline | 335 | del pipeline |
@@ -524,7 +525,7 @@ def main(): | |||
524 | tokenizer=tokenizer, | 525 | tokenizer=tokenizer, |
525 | instance_identifier=args.instance_identifier, | 526 | instance_identifier=args.instance_identifier, |
526 | class_identifier=args.class_identifier, | 527 | class_identifier=args.class_identifier, |
527 | class_subdir="db_cls", | 528 | class_subdir="cls", |
528 | num_class_images=args.num_class_images, | 529 | num_class_images=args.num_class_images, |
529 | size=args.resolution, | 530 | size=args.resolution, |
530 | repeats=args.repeats, | 531 | repeats=args.repeats, |
@@ -580,21 +581,6 @@ def main(): | |||
580 | train_dataloader = datamodule.train_dataloader() | 581 | train_dataloader = datamodule.train_dataloader() |
581 | val_dataloader = datamodule.val_dataloader() | 582 | val_dataloader = datamodule.val_dataloader() |
582 | 583 | ||
583 | checkpointer = Checkpointer( | ||
584 | datamodule=datamodule, | ||
585 | accelerator=accelerator, | ||
586 | vae=vae, | ||
587 | unet=unet, | ||
588 | tokenizer=tokenizer, | ||
589 | text_encoder=text_encoder, | ||
590 | output_dir=basepath, | ||
591 | instance_identifier=args.instance_identifier, | ||
592 | sample_image_size=args.sample_image_size, | ||
593 | sample_batch_size=args.sample_batch_size, | ||
594 | sample_batches=args.sample_batches, | ||
595 | seed=args.seed | ||
596 | ) | ||
597 | |||
598 | # Scheduler and math around the number of training steps. | 584 | # Scheduler and math around the number of training steps. |
599 | overrode_max_train_steps = False | 585 | overrode_max_train_steps = False |
600 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 586 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
@@ -613,7 +599,7 @@ def main(): | |||
613 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 599 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
614 | ) | 600 | ) |
615 | 601 | ||
616 | # Move vae and unet to device | 602 | # Move text_encoder and vae to device |
617 | text_encoder.to(accelerator.device) | 603 | text_encoder.to(accelerator.device) |
618 | vae.to(accelerator.device) | 604 | vae.to(accelerator.device) |
619 | 605 | ||
@@ -649,6 +635,21 @@ def main(): | |||
649 | global_step = 0 | 635 | global_step = 0 |
650 | min_val_loss = np.inf | 636 | min_val_loss = np.inf |
651 | 637 | ||
638 | checkpointer = Checkpointer( | ||
639 | datamodule=datamodule, | ||
640 | accelerator=accelerator, | ||
641 | vae=vae, | ||
642 | unet=unet, | ||
643 | tokenizer=tokenizer, | ||
644 | text_encoder=text_encoder, | ||
645 | output_dir=basepath, | ||
646 | instance_identifier=args.instance_identifier, | ||
647 | sample_image_size=args.sample_image_size, | ||
648 | sample_batch_size=args.sample_batch_size, | ||
649 | sample_batches=args.sample_batches, | ||
650 | seed=args.seed | ||
651 | ) | ||
652 | |||
652 | if accelerator.is_main_process: | 653 | if accelerator.is_main_process: |
653 | checkpointer.save_samples( | 654 | checkpointer.save_samples( |
654 | 0, | 655 | 0, |