diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py index 39c4851..4d7366c 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -59,7 +59,7 @@ def parse_args(): | |||
59 | parser.add_argument( | 59 | parser.add_argument( |
60 | "--repeats", | 60 | "--repeats", |
61 | type=int, | 61 | type=int, |
62 | default=100, | 62 | default=1, |
63 | help="How many times to repeat the training data." | 63 | help="How many times to repeat the training data." |
64 | ) | 64 | ) |
65 | parser.add_argument( | 65 | parser.add_argument( |
@@ -375,7 +375,6 @@ class Checkpointer: | |||
375 | @torch.no_grad() | 375 | @torch.no_grad() |
376 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): | 376 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): |
377 | samples_path = Path(self.output_dir).joinpath("samples") | 377 | samples_path = Path(self.output_dir).joinpath("samples") |
378 | samples_path.mkdir(parents=True, exist_ok=True) | ||
379 | 378 | ||
380 | unwrapped = self.accelerator.unwrap_model(self.unet) | 379 | unwrapped = self.accelerator.unwrap_model(self.unet) |
381 | pipeline = StableDiffusionPipeline( | 380 | pipeline = StableDiffusionPipeline( |
@@ -403,6 +402,7 @@ class Checkpointer: | |||
403 | 402 | ||
404 | all_samples = [] | 403 | all_samples = [] |
405 | file_path = samples_path.joinpath("stable", f"step_{step}.png") | 404 | file_path = samples_path.joinpath("stable", f"step_{step}.png") |
405 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
406 | 406 | ||
407 | data_enum = enumerate(val_data) | 407 | data_enum = enumerate(val_data) |
408 | 408 | ||
@@ -436,6 +436,7 @@ class Checkpointer: | |||
436 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: |
437 | all_samples = [] | 437 | all_samples = [] |
438 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 438 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
439 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
439 | 440 | ||
440 | data_enum = enumerate(data) | 441 | data_enum = enumerate(data) |
441 | 442 | ||
@@ -496,11 +497,15 @@ def main(): | |||
496 | cur_class_images = len(list(class_images_dir.iterdir())) | 497 | cur_class_images = len(list(class_images_dir.iterdir())) |
497 | 498 | ||
498 | if cur_class_images < args.num_class_images: | 499 | if cur_class_images < args.num_class_images: |
499 | torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32 | 500 | torch_dtype = torch.float32 |
501 | if accelerator.device.type == "cuda": | ||
502 | torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision] | ||
503 | |||
500 | pipeline = StableDiffusionPipeline.from_pretrained( | 504 | pipeline = StableDiffusionPipeline.from_pretrained( |
501 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) | 505 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) |
502 | pipeline.enable_attention_slicing() | 506 | pipeline.enable_attention_slicing() |
503 | pipeline.set_progress_bar_config(disable=True) | 507 | pipeline.set_progress_bar_config(disable=True) |
508 | pipeline.to(accelerator.device) | ||
504 | 509 | ||
505 | num_new_images = args.num_class_images - cur_class_images | 510 | num_new_images = args.num_class_images - cur_class_images |
506 | logger.info(f"Number of class images to sample: {num_new_images}.") | 511 | logger.info(f"Number of class images to sample: {num_new_images}.") |
@@ -509,7 +514,6 @@ def main(): | |||
509 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) | 514 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) |
510 | 515 | ||
511 | sample_dataloader = accelerator.prepare(sample_dataloader) | 516 | sample_dataloader = accelerator.prepare(sample_dataloader) |
512 | pipeline.to(accelerator.device) | ||
513 | 517 | ||
514 | for example in tqdm( | 518 | for example in tqdm( |
515 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process | 519 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process |