From 329ad48b307e782b0e23fce80ae9087a4003e73d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 30 Nov 2022 14:02:35 +0100
Subject: Update

---
 dreambooth.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 49d4447..3dd0920 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -115,7 +115,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -267,7 +267,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -297,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=25,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -459,7 +459,7 @@ class Checkpointer:
         torch.cuda.empty_cache()
 
     @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped_unet = self.accelerator.unwrap_model(
@@ -474,13 +474,14 @@ class Checkpointer:
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
+        pipeline.enable_vae_slicing()
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
         stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
             device=pipeline.device,
             generator=generator,
         )
@@ -875,9 +876,7 @@ def main():
     )
 
     if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+        checkpointer.save_samples(0, args.sample_steps)
 
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -1089,9 +1088,7 @@ def main():
                     max_acc_val = avg_acc_val
 
             if sample_checkpoint and accelerator.is_main_process:
-                checkpointer.save_samples(
-                    global_step,
-                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+                checkpointer.save_samples(global_step, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
-- 
cgit v1.2.3-54-g00ecf