diff options
| author | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-09-30 14:13:51 +0200 |
| commit | 9a42def9fcfb9a5c5471d640253ed6c8f45c4973 (patch) | |
| tree | ad186862f5095663966dd1d42455023080aa0c4e /dreambooth.py | |
| parent | Better sample file structure (diff) | |
| download | textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.gz textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.tar.bz2 textual-inversion-diff-9a42def9fcfb9a5c5471d640253ed6c8f45c4973.zip | |
Added custom SD pipeline + euler_a scheduler
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py index 39c4851..4d7366c 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -59,7 +59,7 @@ def parse_args(): | |||
| 59 | parser.add_argument( | 59 | parser.add_argument( |
| 60 | "--repeats", | 60 | "--repeats", |
| 61 | type=int, | 61 | type=int, |
| 62 | default=100, | 62 | default=1, |
| 63 | help="How many times to repeat the training data." | 63 | help="How many times to repeat the training data." |
| 64 | ) | 64 | ) |
| 65 | parser.add_argument( | 65 | parser.add_argument( |
| @@ -375,7 +375,6 @@ class Checkpointer: | |||
| 375 | @torch.no_grad() | 375 | @torch.no_grad() |
| 376 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): | 376 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): |
| 377 | samples_path = Path(self.output_dir).joinpath("samples") | 377 | samples_path = Path(self.output_dir).joinpath("samples") |
| 378 | samples_path.mkdir(parents=True, exist_ok=True) | ||
| 379 | 378 | ||
| 380 | unwrapped = self.accelerator.unwrap_model(self.unet) | 379 | unwrapped = self.accelerator.unwrap_model(self.unet) |
| 381 | pipeline = StableDiffusionPipeline( | 380 | pipeline = StableDiffusionPipeline( |
| @@ -403,6 +402,7 @@ class Checkpointer: | |||
| 403 | 402 | ||
| 404 | all_samples = [] | 403 | all_samples = [] |
| 405 | file_path = samples_path.joinpath("stable", f"step_{step}.png") | 404 | file_path = samples_path.joinpath("stable", f"step_{step}.png") |
| 405 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 406 | 406 | ||
| 407 | data_enum = enumerate(val_data) | 407 | data_enum = enumerate(val_data) |
| 408 | 408 | ||
| @@ -436,6 +436,7 @@ class Checkpointer: | |||
| 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 436 | for data, pool in [(val_data, "val"), (train_data, "train")]: |
| 437 | all_samples = [] | 437 | all_samples = [] |
| 438 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 438 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
| 439 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 439 | 440 | ||
| 440 | data_enum = enumerate(data) | 441 | data_enum = enumerate(data) |
| 441 | 442 | ||
| @@ -496,11 +497,15 @@ def main(): | |||
| 496 | cur_class_images = len(list(class_images_dir.iterdir())) | 497 | cur_class_images = len(list(class_images_dir.iterdir())) |
| 497 | 498 | ||
| 498 | if cur_class_images < args.num_class_images: | 499 | if cur_class_images < args.num_class_images: |
| 499 | torch_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32 | 500 | torch_dtype = torch.float32 |
| 501 | if accelerator.device.type == "cuda": | ||
| 502 | torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision] | ||
| 503 | |||
| 500 | pipeline = StableDiffusionPipeline.from_pretrained( | 504 | pipeline = StableDiffusionPipeline.from_pretrained( |
| 501 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) | 505 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) |
| 502 | pipeline.enable_attention_slicing() | 506 | pipeline.enable_attention_slicing() |
| 503 | pipeline.set_progress_bar_config(disable=True) | 507 | pipeline.set_progress_bar_config(disable=True) |
| 508 | pipeline.to(accelerator.device) | ||
| 504 | 509 | ||
| 505 | num_new_images = args.num_class_images - cur_class_images | 510 | num_new_images = args.num_class_images - cur_class_images |
| 506 | logger.info(f"Number of class images to sample: {num_new_images}.") | 511 | logger.info(f"Number of class images to sample: {num_new_images}.") |
| @@ -509,7 +514,6 @@ def main(): | |||
| 509 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) | 514 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) |
| 510 | 515 | ||
| 511 | sample_dataloader = accelerator.prepare(sample_dataloader) | 516 | sample_dataloader = accelerator.prepare(sample_dataloader) |
| 512 | pipeline.to(accelerator.device) | ||
| 513 | 517 | ||
| 514 | for example in tqdm( | 518 | for example in tqdm( |
| 515 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process | 519 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process |
