 data/dreambooth/csv.py |  10
 dreambooth.py          | 104
 textual_inversion.py   |   2
 3 files changed, 50 insertions(+), 66 deletions(-)
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 14c13bb..85ed4a5 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -108,14 +108,14 @@ class CSVDataset(Dataset):
         else:
             self.class_data_root = None
 
-        self.interpolation = {"linear": PIL.Image.LINEAR,
-                              "bilinear": PIL.Image.BILINEAR,
-                              "bicubic": PIL.Image.BICUBIC,
-                              "lanczos": PIL.Image.LANCZOS,
+        self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
+                              "bilinear": transforms.InterpolationMode.BILINEAR,
+                              "bicubic": transforms.InterpolationMode.BICUBIC,
+                              "lanczos": transforms.InterpolationMode.LANCZOS,
                               }[interpolation]
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.Resize(size, interpolation=self.interpolation),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
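For context, a standalone sketch of what this hunk arrives at: torchvision's InterpolationMode enum (available since torchvision 0.9) replaces the deprecated PIL resampling constants, and Resize now honors the user-selected mode instead of a hard-coded BILINEAR. Only the mapping itself comes from the diff; the size and chosen key below are illustrative. Note that "linear" maps to NEAREST here, not BILINEAR.

    # Sketch only -- the mapping is from the hunk above; 512/"bicubic" are
    # illustrative. Requires torchvision >= 0.9 for InterpolationMode.
    from torchvision import transforms

    interpolation = {"linear": transforms.InterpolationMode.NEAREST,
                     "bilinear": transforms.InterpolationMode.BILINEAR,
                     "bicubic": transforms.InterpolationMode.BICUBIC,
                     "lanczos": transforms.InterpolationMode.LANCZOS,
                     }["bicubic"]

    # Resize now follows the selected mode rather than a fixed BILINEAR:
    resize = transforms.Resize(512, interpolation=interpolation)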
diff --git a/dreambooth.py b/dreambooth.py
index 170b8e9..2df6858 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-6,
+        default=3e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -183,12 +183,6 @@ def parse_args():
         help="For distributed training: local_rank"
     )
     parser.add_argument(
-        "--checkpoint_frequency",
-        type=int,
-        default=200,
-        help="How often to save a checkpoint and sample image",
-    )
-    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=512,
@@ -379,8 +373,8 @@ class Checkpointer:
         torch.cuda.empty_cache()
 
     @torch.no_grad()
-    def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = Path(self.output_dir).joinpath("samples").joinpath(mode)
+    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
+        samples_path = Path(self.output_dir).joinpath("samples")
         samples_path.mkdir(parents=True, exist_ok=True)
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
@@ -397,12 +391,10 @@ class Checkpointer:
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
 
-        data = {
-            "training": self.datamodule.train_dataloader(),
-            "validation": self.datamodule.val_dataloader(),
-        }[mode]
+        train_data = self.datamodule.train_dataloader()
+        val_data = self.datamodule.val_dataloader()
 
-        if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
+        if self.stable_sample_batches > 0:
             stable_latents = torch.randn(
                 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
                 device=pipeline.device,
@@ -410,14 +402,14 @@ class Checkpointer:
             )
 
             all_samples = []
-            filename = f"stable_step_%d.png" % (step)
+            filename = f"step_{step}_val_stable.png"
 
-            data_enum = enumerate(data)
+            data_enum = enumerate(val_data)
 
             # Generate and save stable samples
             for i in range(0, self.stable_sample_batches):
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
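A quick aside on the fixed-latent ("stable") samples above: drawing the initial noise once and reusing it at every checkpoint means the sample grid changes only because the model changed, which makes checkpoints comparable. A minimal sketch of the idea with illustrative sizes; only the (batch, channels, height // 8, width // 8) latent shape follows the diff, and the `latents=` kwarg is the usual diffusers pipeline parameter, not shown in this hunk.

    # Sketch: reuse one noise tensor across checkpoints so sample images
    # differ only due to training progress. Sizes are illustrative.
    import torch

    sample_batch_size, in_channels, height, width = 4, 4, 512, 512
    stable_latents = torch.randn(
        (sample_batch_size, in_channels, height // 8, width // 8),
    )
    # at every checkpoint: pipeline(prompt=..., latents=stable_latents, ...)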
@@ -441,35 +433,35 @@ class Checkpointer:
             del image_grid
             del stable_latents
 
-        all_samples = []
-        filename = f"step_%d.png" % (step)
+        for data, pool in [(train_data, "train"), (val_data, "val")]:
+            all_samples = []
+            filename = f"step_{step}_{pool}.png"
 
-        data_enum = enumerate(data)
+            data_enum = enumerate(data)
 
-        # Generate and save random samples
-        for i in range(0, self.random_sample_batches):
-            prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(0, self.random_sample_batches):
+                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
 
-            with self.accelerator.autocast():
-                samples = pipeline(
-                    prompt=prompt,
-                    height=self.sample_image_size,
-                    width=self.sample_image_size,
-                    guidance_scale=guidance_scale,
-                    eta=eta,
-                    num_inference_steps=num_inference_steps,
-                    output_type='pil'
-                )["sample"]
+                with self.accelerator.autocast():
+                    samples = pipeline(
+                        prompt=prompt,
+                        height=self.sample_image_size,
+                        width=self.sample_image_size,
+                        guidance_scale=guidance_scale,
+                        eta=eta,
+                        num_inference_steps=num_inference_steps,
+                        output_type='pil'
+                    )["sample"]
 
-            all_samples += samples
-            del samples
+                all_samples += samples
+                del samples
 
-        image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
-        image_grid.save(f"{samples_path}/{filename}")
+            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid.save(f"{samples_path}/{filename}")
 
-            del all_samples
-            del image_grid
+            del all_samples
+            del image_grid
 
         del unwrapped
         del pipeline
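The nested comprehension that gathers prompts is unchanged in meaning (it just moved one indentation level deeper), but it is dense; a minimal sketch of what it does, with toy data standing in for real dataloaders:

    # Sketch of the prompt-gathering comprehension with toy data: it
    # flattens batches of prompts and keeps the first sample_batch_size.
    batches = [{"prompts": ["a", "b"]}, {"prompts": ["c", "d"]}]
    batch_size = 2
    sample_batch_size = 3

    data_enum = enumerate(batches)
    prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
        batch["prompts"]) if i * batch_size + j < sample_batch_size]
    assert prompt == ["a", "b", "c"]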
@@ -594,8 +586,7 @@ def main():
         beta_start=0.00085,
         beta_end=0.012,
         beta_schedule="scaled_linear",
-        num_train_timesteps=1000,
-        tensor_format="pt"
+        num_train_timesteps=1000
     )
 
     def collate_fn(examples):
@@ -687,6 +678,7 @@ def main():
 
     num_val_steps_per_epoch = len(val_dataloader)
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    val_steps = num_val_steps_per_epoch * num_epochs
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
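For a sense of what `val_steps` contributes to the global progress bar used below, a worked example; the formulas are from the diff, the numbers are made up:

    # Worked example (illustrative numbers): the global bar now counts
    # validation batches as well as training update steps.
    import math

    max_train_steps = 800
    num_update_steps_per_epoch = 100   # len(train) after grad accumulation
    num_val_steps_per_epoch = 20       # len(val_dataloader)

    num_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)  # 8
    val_steps = num_val_steps_per_epoch * num_epochs                      # 160
    total_progress = max_train_steps + val_steps                          # 960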
@@ -707,16 +699,16 @@ def main():
     global_step = 0
     min_val_loss = np.inf
 
-    checkpointer.save_samples(
-        "validation",
-        0,
-        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+    if accelerator.is_main_process:
+        checkpointer.save_samples(
+            0,
+            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
                               disable=not accelerator.is_local_main_process)
     local_progress_bar.set_description("Batch X out of Y")
 
-    global_progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
     global_progress_bar.set_description("Total progress")
 
     try:
@@ -789,15 +781,6 @@ def main():
 
             global_step += 1
 
-            if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                local_progress_bar.clear()
-                global_progress_bar.clear()
-
-                checkpointer.save_samples(
-                    "training",
-                    global_step,
-                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
-
             logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
             local_progress_bar.set_postfix(**logs)
 
@@ -847,6 +830,7 @@ def main():
 
             if accelerator.sync_gradients:
                 local_progress_bar.update(1)
+                global_progress_bar.update(1)
 
             logs = {"mode": "validation", "loss": loss}
            local_progress_bar.set_postfix(**logs)
@@ -862,10 +846,10 @@ def main():
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 min_val_loss = val_loss
 
-            checkpointer.save_samples(
-                "validation",
-                global_step,
-                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+            if accelerator.is_main_process:
+                checkpointer.save_samples(
+                    global_step,
+                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
         accelerator.wait_for_everyone()
 
diff --git a/textual_inversion.py b/textual_inversion.py
index 81f1cf5..399d876 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -581,7 +581,7 @@ def main():
 
     # TODO (patil-suraj): laod scheduler using args
     noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
 
     datamodule = CSVDataModule(
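Both files drop `tensor_format="pt"`; newer diffusers releases removed that argument from the schedulers, which now always operate on PyTorch tensors. A sketch of the resulting construction, assuming such a diffusers version (the kwargs are as in the diff):

    # Sketch: DDPMScheduler without tensor_format, as in both hunks above.
    from diffusers import DDPMScheduler

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )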
