Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  19
1 file changed, 8 insertions(+), 11 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 49d4447..3dd0920 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -115,7 +115,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -267,7 +267,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -297,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=25,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -459,7 +459,7 @@ class Checkpointer:
         torch.cuda.empty_cache()

     @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")

         unwrapped_unet = self.accelerator.unwrap_model(
@@ -474,13 +474,14 @@ class Checkpointer:
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
+        pipeline.enable_vae_slicing()

         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()

         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
         stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
             device=pipeline.device,
             generator=generator,
         )
@@ -875,9 +876,7 @@ def main():
     )

     if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+        checkpointer.save_samples(0, args.sample_steps)

     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -1089,9 +1088,7 @@ def main():
                    max_acc_val = avg_acc_val

            if sample_checkpoint and accelerator.is_main_process:
-                checkpointer.save_samples(
-                    global_step,
-                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+                checkpointer.save_samples(global_step, args.sample_steps)

    # Create the pipeline using using the trained modules and save it.
    if accelerator.is_main_process:
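
Note: the sketch below is a minimal, self-contained illustration of the simplified save_samples signature introduced by this diff, not the project's actual Checkpointer. The class name, constructor values, and the assumed latent channel count (4) are stand-ins; only the signature, the defaults for guidance_scale/eta, and the derivation of the latent shape from sample_image_size are taken from the change itself.

import torch

class CheckpointerSketch:
    """Toy stand-in reproducing only the pieces this diff touches."""

    def __init__(self, sample_image_size=768, sample_batch_size=4, seed=42):
        self.sample_image_size = sample_image_size
        self.sample_batch_size = sample_batch_size
        self.seed = seed

    @torch.no_grad()
    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        # After the change, the latent resolution is derived from
        # self.sample_image_size rather than explicit height/width arguments
        # (latents are 1/8 of the pixel resolution for Stable Diffusion VAEs).
        in_channels = 4  # assumed UNet latent channel count
        generator = torch.Generator().manual_seed(self.seed)
        stable_latents = torch.randn(
            (self.sample_batch_size, in_channels,
             self.sample_image_size // 8, self.sample_image_size // 8),
            generator=generator,
        )
        return stable_latents.shape

if __name__ == "__main__":
    ckpt = CheckpointerSketch()
    # Callers now pass only the step and the number of inference steps,
    # mirroring checkpointer.save_samples(0, args.sample_steps) in the diff;
    # guidance_scale and eta fall back to their new defaults (7.5 and 0.0).
    print(ckpt.save_samples(0, num_inference_steps=15))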