diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 19 |
1 files changed, 8 insertions, 11 deletions
diff --git a/dreambooth.py b/dreambooth.py index 49d4447..3dd0920 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -115,7 +115,7 @@ def parse_args(): | |||
115 | parser.add_argument( | 115 | parser.add_argument( |
116 | "--resolution", | 116 | "--resolution", |
117 | type=int, | 117 | type=int, |
118 | default=512, | 118 | default=768, |
119 | help=( | 119 | help=( |
120 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" | 120 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" |
121 | " resolution" | 121 | " resolution" |
@@ -267,7 +267,7 @@ def parse_args(): | |||
267 | parser.add_argument( | 267 | parser.add_argument( |
268 | "--sample_image_size", | 268 | "--sample_image_size", |
269 | type=int, | 269 | type=int, |
270 | default=512, | 270 | default=768, |
271 | help="Size of sample images", | 271 | help="Size of sample images", |
272 | ) | 272 | ) |
273 | parser.add_argument( | 273 | parser.add_argument( |
@@ -297,7 +297,7 @@ def parse_args(): | |||
297 | parser.add_argument( | 297 | parser.add_argument( |
298 | "--sample_steps", | 298 | "--sample_steps", |
299 | type=int, | 299 | type=int, |
300 | default=25, | 300 | default=15, |
301 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 301 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
302 | ) | 302 | ) |
303 | parser.add_argument( | 303 | parser.add_argument( |
@@ -459,7 +459,7 @@ class Checkpointer: | |||
459 | torch.cuda.empty_cache() | 459 | torch.cuda.empty_cache() |
460 | 460 | ||
461 | @torch.no_grad() | 461 | @torch.no_grad() |
462 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): | 462 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
463 | samples_path = Path(self.output_dir).joinpath("samples") | 463 | samples_path = Path(self.output_dir).joinpath("samples") |
464 | 464 | ||
465 | unwrapped_unet = self.accelerator.unwrap_model( | 465 | unwrapped_unet = self.accelerator.unwrap_model( |
@@ -474,13 +474,14 @@ class Checkpointer: | |||
474 | scheduler=self.scheduler, | 474 | scheduler=self.scheduler, |
475 | ).to(self.accelerator.device) | 475 | ).to(self.accelerator.device) |
476 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 476 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
477 | pipeline.enable_vae_slicing() | ||
477 | 478 | ||
478 | train_data = self.datamodule.train_dataloader() | 479 | train_data = self.datamodule.train_dataloader() |
479 | val_data = self.datamodule.val_dataloader() | 480 | val_data = self.datamodule.val_dataloader() |
480 | 481 | ||
481 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 482 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
482 | stable_latents = torch.randn( | 483 | stable_latents = torch.randn( |
483 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | 484 | (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8), |
484 | device=pipeline.device, | 485 | device=pipeline.device, |
485 | generator=generator, | 486 | generator=generator, |
486 | ) | 487 | ) |
@@ -875,9 +876,7 @@ def main(): | |||
875 | ) | 876 | ) |
876 | 877 | ||
877 | if accelerator.is_main_process: | 878 | if accelerator.is_main_process: |
878 | checkpointer.save_samples( | 879 | checkpointer.save_samples(0, args.sample_steps) |
879 | 0, | ||
880 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
881 | 880 | ||
882 | local_progress_bar = tqdm( | 881 | local_progress_bar = tqdm( |
883 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 882 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
@@ -1089,9 +1088,7 @@ def main(): | |||
1089 | max_acc_val = avg_acc_val | 1088 | max_acc_val = avg_acc_val |
1090 | 1089 | ||
1091 | if sample_checkpoint and accelerator.is_main_process: | 1090 | if sample_checkpoint and accelerator.is_main_process: |
1092 | checkpointer.save_samples( | 1091 | checkpointer.save_samples(global_step, args.sample_steps) |
1093 | global_step, | ||
1094 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
1095 | 1092 | ||
1096 | # Create the pipeline using using the trained modules and save it. | 1093 | # Create the pipeline using using the trained modules and save it. |
1097 | if accelerator.is_main_process: | 1094 | if accelerator.is_main_process: |