summary | refs | log | tree | commit | diff | stats
path: root/dreambooth.py
diff options
context:
space:
mode:
author    Volpeon <git@volpeon.ink>  2022-09-30 14:13:51 +0200
committer Volpeon <git@volpeon.ink>  2022-09-30 14:13:51 +0200
commit   9a42def9fcfb9a5c5471d640253ed6c8f45c4973 (patch)
tree     ad186862f5095663966dd1d42455023080aa0c4e /dreambooth.py
parent   Better sample file structure (diff)
download textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.gz
textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.bz2
textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.zip
Added custom SD pipeline + euler_a scheduler
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py | 12
1 file changed, 8 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 39c4851..4d7366c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -59,7 +59,7 @@ def parse_args():
59 parser.add_argument( 59 parser.add_argument(
60 "--repeats", 60 "--repeats",
61 type=int, 61 type=int,
62 default=100, 62 default=1,
63 help="How many times to repeat the training data." 63 help="How many times to repeat the training data."
64 ) 64 )
65 parser.add_argument( 65 parser.add_argument(
@@ -375,7 +375,6 @@ class Checkpointer:
375 @torch.no_grad() 375 @torch.no_grad()
376 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): 376 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
377 samples_path = Path(self.output_dir).joinpath("samples") 377 samples_path = Path(self.output_dir).joinpath("samples")
378 samples_path.mkdir(parents=True, exist_ok=True)
379 378
380 unwrapped = self.accelerator.unwrap_model(self.unet) 379 unwrapped = self.accelerator.unwrap_model(self.unet)
381 pipeline = StableDiffusionPipeline( 380 pipeline = StableDiffusionPipeline(
@@ -403,6 +402,7 @@ class Checkpointer:
403 402
404 all_samples = [] 403 all_samples = []
405 file_path = samples_path.joinpath("stable", f"step_{step}.png") 404 file_path = samples_path.joinpath("stable", f"step_{step}.png")
405 file_path.parent.mkdir(parents=True, exist_ok=True)
406 406
407 data_enum = enumerate(val_data) 407 data_enum = enumerate(val_data)
408 408
@@ -436,6 +436,7 @@ class Checkpointer:
436 for data, pool in [(val_data, "val"), (train_data, "train")]: 436 for data, pool in [(val_data, "val"), (train_data, "train")]:
437 all_samples = [] 437 all_samples = []
438 file_path = samples_path.joinpath(pool, f"step_{step}.png") 438 file_path = samples_path.joinpath(pool, f"step_{step}.png")
439 file_path.parent.mkdir(parents=True, exist_ok=True)
439 440
440 data_enum = enumerate(data) 441 data_enum = enumerate(data)
441 442
@@ -496,11 +497,15 @@ def main():
496 cur_class_images = len(list(class_images_dir.iterdir())) 497 cur_class_images = len(list(class_images_dir.iterdir()))
497 498
498 if cur_class_images < args.num_class_images: 499 if cur_class_images < args.num_class_images:
499 torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32 500 torch_dtype = torch.float32
501 if accelerator.device.type == "cuda":
502 torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
503
500 pipeline = StableDiffusionPipeline.from_pretrained( 504 pipeline = StableDiffusionPipeline.from_pretrained(
501 args.pretrained_model_name_or_path, torch_dtype=torch_dtype) 505 args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
502 pipeline.enable_attention_slicing() 506 pipeline.enable_attention_slicing()
503 pipeline.set_progress_bar_config(disable=True) 507 pipeline.set_progress_bar_config(disable=True)
508 pipeline.to(accelerator.device)
504 509
505 num_new_images = args.num_class_images - cur_class_images 510 num_new_images = args.num_class_images - cur_class_images
506 logger.info(f"Number of class images to sample: {num_new_images}.") 511 logger.info(f"Number of class images to sample: {num_new_images}.")
@@ -509,7 +514,6 @@ def main():
509 sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 514 sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
510 515
511 sample_dataloader = accelerator.prepare(sample_dataloader) 516 sample_dataloader = accelerator.prepare(sample_dataloader)
512 pipeline.to(accelerator.device)
513 517
514 for example in tqdm( 518 for example in tqdm(
515 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 519 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process