diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 19 |
1 files changed, 8 insertions, 11 deletions
diff --git a/dreambooth.py b/dreambooth.py index 49d4447..3dd0920 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -115,7 +115,7 @@ def parse_args(): | |||
| 115 | parser.add_argument( | 115 | parser.add_argument( |
| 116 | "--resolution", | 116 | "--resolution", |
| 117 | type=int, | 117 | type=int, |
| 118 | default=512, | 118 | default=768, |
| 119 | help=( | 119 | help=( |
| 120 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" | 120 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" |
| 121 | " resolution" | 121 | " resolution" |
| @@ -267,7 +267,7 @@ def parse_args(): | |||
| 267 | parser.add_argument( | 267 | parser.add_argument( |
| 268 | "--sample_image_size", | 268 | "--sample_image_size", |
| 269 | type=int, | 269 | type=int, |
| 270 | default=512, | 270 | default=768, |
| 271 | help="Size of sample images", | 271 | help="Size of sample images", |
| 272 | ) | 272 | ) |
| 273 | parser.add_argument( | 273 | parser.add_argument( |
| @@ -297,7 +297,7 @@ def parse_args(): | |||
| 297 | parser.add_argument( | 297 | parser.add_argument( |
| 298 | "--sample_steps", | 298 | "--sample_steps", |
| 299 | type=int, | 299 | type=int, |
| 300 | default=25, | 300 | default=15, |
| 301 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 301 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
| 302 | ) | 302 | ) |
| 303 | parser.add_argument( | 303 | parser.add_argument( |
| @@ -459,7 +459,7 @@ class Checkpointer: | |||
| 459 | torch.cuda.empty_cache() | 459 | torch.cuda.empty_cache() |
| 460 | 460 | ||
| 461 | @torch.no_grad() | 461 | @torch.no_grad() |
| 462 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): | 462 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
| 463 | samples_path = Path(self.output_dir).joinpath("samples") | 463 | samples_path = Path(self.output_dir).joinpath("samples") |
| 464 | 464 | ||
| 465 | unwrapped_unet = self.accelerator.unwrap_model( | 465 | unwrapped_unet = self.accelerator.unwrap_model( |
| @@ -474,13 +474,14 @@ class Checkpointer: | |||
| 474 | scheduler=self.scheduler, | 474 | scheduler=self.scheduler, |
| 475 | ).to(self.accelerator.device) | 475 | ).to(self.accelerator.device) |
| 476 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 476 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 477 | pipeline.enable_vae_slicing() | ||
| 477 | 478 | ||
| 478 | train_data = self.datamodule.train_dataloader() | 479 | train_data = self.datamodule.train_dataloader() |
| 479 | val_data = self.datamodule.val_dataloader() | 480 | val_data = self.datamodule.val_dataloader() |
| 480 | 481 | ||
| 481 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 482 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
| 482 | stable_latents = torch.randn( | 483 | stable_latents = torch.randn( |
| 483 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | 484 | (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8), |
| 484 | device=pipeline.device, | 485 | device=pipeline.device, |
| 485 | generator=generator, | 486 | generator=generator, |
| 486 | ) | 487 | ) |
| @@ -875,9 +876,7 @@ def main(): | |||
| 875 | ) | 876 | ) |
| 876 | 877 | ||
| 877 | if accelerator.is_main_process: | 878 | if accelerator.is_main_process: |
| 878 | checkpointer.save_samples( | 879 | checkpointer.save_samples(0, args.sample_steps) |
| 879 | 0, | ||
| 880 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
| 881 | 880 | ||
| 882 | local_progress_bar = tqdm( | 881 | local_progress_bar = tqdm( |
| 883 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 882 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
| @@ -1089,9 +1088,7 @@ def main(): | |||
| 1089 | max_acc_val = avg_acc_val | 1088 | max_acc_val = avg_acc_val |
| 1090 | 1089 | ||
| 1091 | if sample_checkpoint and accelerator.is_main_process: | 1090 | if sample_checkpoint and accelerator.is_main_process: |
| 1092 | checkpointer.save_samples( | 1091 | checkpointer.save_samples(global_step, args.sample_steps) |
| 1093 | global_step, | ||
| 1094 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
| 1095 | 1092 | ||
| 1096 | # Create the pipeline using using the trained modules and save it. | 1093 | # Create the pipeline using using the trained modules and save it. |
| 1097 | if accelerator.is_main_process: | 1094 | if accelerator.is_main_process: |
