From 2af0d47b44fe02269b1378f7691d258d35544bb3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 7 Oct 2022 14:54:44 +0200 Subject: Fix small details --- dreambooth.py | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 24e6091..a26bea7 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -68,7 +68,7 @@ def parse_args(): parser.add_argument( "--num_class_images", type=int, - default=2, + default=4, help="How many class images to generate per training image." ) parser.add_argument( @@ -106,7 +106,8 @@ def parse_args(): parser.add_argument( "--num_train_epochs", type=int, - default=100) + default=100 + ) parser.add_argument( "--max_train_steps", type=int, @@ -293,7 +294,7 @@ class Checkpointer: unet, tokenizer, text_encoder, - output_dir, + output_dir: Path, instance_identifier, sample_image_size, sample_batches, @@ -321,14 +322,14 @@ class Checkpointer: pipeline = VlpnStableDiffusion( text_encoder=self.text_encoder, vae=self.vae, - unet=self.accelerator.unwrap_model(self.unet), + unet=unwrapped, tokenizer=self.tokenizer, scheduler=PNDMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True ), ) pipeline.enable_attention_slicing() - pipeline.save_pretrained(f"{self.output_dir}/model") + pipeline.save_pretrained(self.output_dir.joinpath("model")) del unwrapped del pipeline @@ -524,7 +525,7 @@ def main(): tokenizer=tokenizer, instance_identifier=args.instance_identifier, class_identifier=args.class_identifier, - class_subdir="db_cls", + class_subdir="cls", num_class_images=args.num_class_images, size=args.resolution, repeats=args.repeats, @@ -580,21 +581,6 @@ def main(): train_dataloader = datamodule.train_dataloader() val_dataloader = datamodule.val_dataloader() - checkpointer = Checkpointer( - datamodule=datamodule, - accelerator=accelerator, - vae=vae, - unet=unet, - tokenizer=tokenizer, - text_encoder=text_encoder, - output_dir=basepath, - instance_identifier=args.instance_identifier, - sample_image_size=args.sample_image_size, - sample_batch_size=args.sample_batch_size, - sample_batches=args.sample_batches, - seed=args.seed - ) - # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -613,7 +599,7 @@ def main(): unet, optimizer, train_dataloader, val_dataloader, lr_scheduler ) - # Move vae and unet to device + # Move text_encoder and vae to device text_encoder.to(accelerator.device) vae.to(accelerator.device) @@ -649,6 +635,21 @@ def main(): global_step = 0 min_val_loss = np.inf + checkpointer = Checkpointer( + datamodule=datamodule, + accelerator=accelerator, + vae=vae, + unet=unet, + tokenizer=tokenizer, + text_encoder=text_encoder, + output_dir=basepath, + instance_identifier=args.instance_identifier, + sample_image_size=args.sample_image_size, + sample_batch_size=args.sample_batch_size, + sample_batches=args.sample_batches, + seed=args.seed + ) + if accelerator.is_main_process: checkpointer.save_samples( 0, -- cgit v1.2.3-54-g00ecf