From 9a42def9fcfb9a5c5471d640253ed6c8f45c4973 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 30 Sep 2022 14:13:51 +0200
Subject: Added custom SD pipeline + euler_a scheduler

---
 dreambooth.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 39c4851..4d7366c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -59,7 +59,7 @@ def parse_args():
     parser.add_argument(
         "--repeats",
         type=int,
-        default=100,
+        default=1,
         help="How many times to repeat the training data."
     )
     parser.add_argument(
@@ -375,7 +375,6 @@ class Checkpointer:
     @torch.no_grad()
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
-        samples_path.mkdir(parents=True, exist_ok=True)
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
         pipeline = StableDiffusionPipeline(
@@ -403,6 +402,7 @@ class Checkpointer:
 
             all_samples = []
             file_path = samples_path.joinpath("stable", f"step_{step}.png")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
 
             data_enum = enumerate(val_data)
 
@@ -436,6 +436,7 @@ class Checkpointer:
         for data, pool in [(val_data, "val"), (train_data, "train")]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.png")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
 
             data_enum = enumerate(data)
 
@@ -496,11 +497,15 @@ def main():
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32
+            torch_dtype = torch.float32
+            if accelerator.device.type == "cuda":
+                torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
+
             pipeline = StableDiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
             pipeline.enable_attention_slicing()
             pipeline.set_progress_bar_config(disable=True)
+            pipeline.to(accelerator.device)
 
             num_new_images = args.num_class_images - cur_class_images
             logger.info(f"Number of class images to sample: {num_new_images}.")
@@ -509,7 +514,6 @@ def main():
             sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
 
             sample_dataloader = accelerator.prepare(sample_dataloader)
-            pipeline.to(accelerator.device)
 
             for example in tqdm(
                 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
-- 
cgit v1.2.3-54-g00ecf
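
For context, the dtype change in the @@ -496,11 +497,15 @@ hunk maps the Accelerate --mixed_precision setting onto a torch dtype before loading the class-image generation pipeline, instead of always assuming bfloat16 on CUDA. Below is a minimal standalone sketch of that pattern, not part of the patch; it assumes diffusers' StableDiffusionPipeline and accelerate's Accelerator, and the model id is a placeholder standing in for args.pretrained_model_name_or_path.

    import torch
    from accelerate import Accelerator
    from diffusers import StableDiffusionPipeline

    accelerator = Accelerator(mixed_precision="fp16")  # one of "no", "fp16", "bf16"

    # Default to full precision; only narrow the dtype when sampling on CUDA,
    # mirroring the mapping introduced in the patch.
    torch_dtype = torch.float32
    if accelerator.device.type == "cuda":
        torch_dtype = {
            "no": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }[accelerator.mixed_precision]

    # Placeholder model id; the training script uses args.pretrained_model_name_or_path.
    pipeline = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype
    )
    pipeline.enable_attention_slicing()
    pipeline.set_progress_bar_config(disable=True)
    # Moved before the sample dataloader is built, as in the patch.
    pipeline.to(accelerator.device)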