diff options
-rw-r--r-- | data/dreambooth/csv.py | 10 | ||||
-rw-r--r-- | dreambooth.py | 104 | ||||
-rw-r--r-- | textual_inversion.py | 2 |
3 files changed, 50 insertions, 66 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 14c13bb..85ed4a5 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -108,14 +108,14 @@ class CSVDataset(Dataset): | |||
108 | else: | 108 | else: |
109 | self.class_data_root = None | 109 | self.class_data_root = None |
110 | 110 | ||
111 | self.interpolation = {"linear": PIL.Image.LINEAR, | 111 | self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, |
112 | "bilinear": PIL.Image.BILINEAR, | 112 | "bilinear": transforms.InterpolationMode.BILINEAR, |
113 | "bicubic": PIL.Image.BICUBIC, | 113 | "bicubic": transforms.InterpolationMode.BICUBIC, |
114 | "lanczos": PIL.Image.LANCZOS, | 114 | "lanczos": transforms.InterpolationMode.LANCZOS, |
115 | }[interpolation] | 115 | }[interpolation] |
116 | self.image_transforms = transforms.Compose( | 116 | self.image_transforms = transforms.Compose( |
117 | [ | 117 | [ |
118 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), | 118 | transforms.Resize(size, interpolation=self.interpolation), |
119 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), | 119 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), |
120 | transforms.RandomHorizontalFlip(), | 120 | transforms.RandomHorizontalFlip(), |
121 | transforms.ToTensor(), | 121 | transforms.ToTensor(), |
diff --git a/dreambooth.py b/dreambooth.py index 170b8e9..2df6858 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -112,7 +112,7 @@ def parse_args(): | |||
112 | parser.add_argument( | 112 | parser.add_argument( |
113 | "--learning_rate", | 113 | "--learning_rate", |
114 | type=float, | 114 | type=float, |
115 | default=5e-6, | 115 | default=3e-6, |
116 | help="Initial learning rate (after the potential warmup period) to use.", | 116 | help="Initial learning rate (after the potential warmup period) to use.", |
117 | ) | 117 | ) |
118 | parser.add_argument( | 118 | parser.add_argument( |
@@ -183,12 +183,6 @@ def parse_args(): | |||
183 | help="For distributed training: local_rank" | 183 | help="For distributed training: local_rank" |
184 | ) | 184 | ) |
185 | parser.add_argument( | 185 | parser.add_argument( |
186 | "--checkpoint_frequency", | ||
187 | type=int, | ||
188 | default=200, | ||
189 | help="How often to save a checkpoint and sample image", | ||
190 | ) | ||
191 | parser.add_argument( | ||
192 | "--sample_image_size", | 186 | "--sample_image_size", |
193 | type=int, | 187 | type=int, |
194 | default=512, | 188 | default=512, |
@@ -379,8 +373,8 @@ class Checkpointer: | |||
379 | torch.cuda.empty_cache() | 373 | torch.cuda.empty_cache() |
380 | 374 | ||
381 | @torch.no_grad() | 375 | @torch.no_grad() |
382 | def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps): | 376 | def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): |
383 | samples_path = Path(self.output_dir).joinpath("samples").joinpath(mode) | 377 | samples_path = Path(self.output_dir).joinpath("samples") |
384 | samples_path.mkdir(parents=True, exist_ok=True) | 378 | samples_path.mkdir(parents=True, exist_ok=True) |
385 | 379 | ||
386 | unwrapped = self.accelerator.unwrap_model(self.unet) | 380 | unwrapped = self.accelerator.unwrap_model(self.unet) |
@@ -397,12 +391,10 @@ class Checkpointer: | |||
397 | ).to(self.accelerator.device) | 391 | ).to(self.accelerator.device) |
398 | pipeline.enable_attention_slicing() | 392 | pipeline.enable_attention_slicing() |
399 | 393 | ||
400 | data = { | 394 | train_data = self.datamodule.train_dataloader() |
401 | "training": self.datamodule.train_dataloader(), | 395 | val_data = self.datamodule.val_dataloader() |
402 | "validation": self.datamodule.val_dataloader(), | ||
403 | }[mode] | ||
404 | 396 | ||
405 | if mode == "validation" and self.stable_sample_batches > 0 and step > 0: | 397 | if self.stable_sample_batches > 0: |
406 | stable_latents = torch.randn( | 398 | stable_latents = torch.randn( |
407 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | 399 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), |
408 | device=pipeline.device, | 400 | device=pipeline.device, |
@@ -410,14 +402,14 @@ class Checkpointer: | |||
410 | ) | 402 | ) |
411 | 403 | ||
412 | all_samples = [] | 404 | all_samples = [] |
413 | filename = f"stable_step_%d.png" % (step) | 405 | filename = f"step_{step}_val_stable.png" |
414 | 406 | ||
415 | data_enum = enumerate(data) | 407 | data_enum = enumerate(val_data) |
416 | 408 | ||
417 | # Generate and save stable samples | 409 | # Generate and save stable samples |
418 | for i in range(0, self.stable_sample_batches): | 410 | for i in range(0, self.stable_sample_batches): |
419 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 411 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
420 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 412 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] |
421 | 413 | ||
422 | with self.accelerator.autocast(): | 414 | with self.accelerator.autocast(): |
423 | samples = pipeline( | 415 | samples = pipeline( |
@@ -441,35 +433,35 @@ class Checkpointer: | |||
441 | del image_grid | 433 | del image_grid |
442 | del stable_latents | 434 | del stable_latents |
443 | 435 | ||
444 | all_samples = [] | 436 | for data, pool in [(train_data, "train"), (val_data, "val")]: |
445 | filename = f"step_%d.png" % (step) | 437 | all_samples = [] |
438 | filename = f"step_{step}_{pool}.png" | ||
446 | 439 | ||
447 | data_enum = enumerate(data) | 440 | data_enum = enumerate(data) |
448 | 441 | ||
449 | # Generate and save random samples | 442 | for i in range(0, self.random_sample_batches): |
450 | for i in range(0, self.random_sample_batches): | 443 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
451 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 444 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] |
452 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | ||
453 | 445 | ||
454 | with self.accelerator.autocast(): | 446 | with self.accelerator.autocast(): |
455 | samples = pipeline( | 447 | samples = pipeline( |
456 | prompt=prompt, | 448 | prompt=prompt, |
457 | height=self.sample_image_size, | 449 | height=self.sample_image_size, |
458 | width=self.sample_image_size, | 450 | width=self.sample_image_size, |
459 | guidance_scale=guidance_scale, | 451 | guidance_scale=guidance_scale, |
460 | eta=eta, | 452 | eta=eta, |
461 | num_inference_steps=num_inference_steps, | 453 | num_inference_steps=num_inference_steps, |
462 | output_type='pil' | 454 | output_type='pil' |
463 | )["sample"] | 455 | )["sample"] |
464 | 456 | ||
465 | all_samples += samples | 457 | all_samples += samples |
466 | del samples | 458 | del samples |
467 | 459 | ||
468 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 460 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |
469 | image_grid.save(f"{samples_path}/{filename}") | 461 | image_grid.save(f"{samples_path}/{filename}") |
470 | 462 | ||
471 | del all_samples | 463 | del all_samples |
472 | del image_grid | 464 | del image_grid |
473 | 465 | ||
474 | del unwrapped | 466 | del unwrapped |
475 | del pipeline | 467 | del pipeline |
@@ -594,8 +586,7 @@ def main(): | |||
594 | beta_start=0.00085, | 586 | beta_start=0.00085, |
595 | beta_end=0.012, | 587 | beta_end=0.012, |
596 | beta_schedule="scaled_linear", | 588 | beta_schedule="scaled_linear", |
597 | num_train_timesteps=1000, | 589 | num_train_timesteps=1000 |
598 | tensor_format="pt" | ||
599 | ) | 590 | ) |
600 | 591 | ||
601 | def collate_fn(examples): | 592 | def collate_fn(examples): |
@@ -687,6 +678,7 @@ def main(): | |||
687 | 678 | ||
688 | num_val_steps_per_epoch = len(val_dataloader) | 679 | num_val_steps_per_epoch = len(val_dataloader) |
689 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 680 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
681 | val_steps = num_val_steps_per_epoch * num_epochs | ||
690 | 682 | ||
691 | # We need to initialize the trackers we use, and also store our configuration. | 683 | # We need to initialize the trackers we use, and also store our configuration. |
692 | # The trackers initialize automatically on the main process. | 684 | # The trackers initialize automatically on the main process. |
@@ -707,16 +699,16 @@ def main(): | |||
707 | global_step = 0 | 699 | global_step = 0 |
708 | min_val_loss = np.inf | 700 | min_val_loss = np.inf |
709 | 701 | ||
710 | checkpointer.save_samples( | 702 | if accelerator.is_main_process: |
711 | "validation", | 703 | checkpointer.save_samples( |
712 | 0, | 704 | 0, |
713 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 705 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
714 | 706 | ||
715 | local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), | 707 | local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
716 | disable=not accelerator.is_local_main_process) | 708 | disable=not accelerator.is_local_main_process) |
717 | local_progress_bar.set_description("Batch X out of Y") | 709 | local_progress_bar.set_description("Batch X out of Y") |
718 | 710 | ||
719 | global_progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | 711 | global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process) |
720 | global_progress_bar.set_description("Total progress") | 712 | global_progress_bar.set_description("Total progress") |
721 | 713 | ||
722 | try: | 714 | try: |
@@ -789,15 +781,6 @@ def main(): | |||
789 | 781 | ||
790 | global_step += 1 | 782 | global_step += 1 |
791 | 783 | ||
792 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | ||
793 | local_progress_bar.clear() | ||
794 | global_progress_bar.clear() | ||
795 | |||
796 | checkpointer.save_samples( | ||
797 | "training", | ||
798 | global_step, | ||
799 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
800 | |||
801 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 784 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} |
802 | local_progress_bar.set_postfix(**logs) | 785 | local_progress_bar.set_postfix(**logs) |
803 | 786 | ||
@@ -847,6 +830,7 @@ def main(): | |||
847 | 830 | ||
848 | if accelerator.sync_gradients: | 831 | if accelerator.sync_gradients: |
849 | local_progress_bar.update(1) | 832 | local_progress_bar.update(1) |
833 | global_progress_bar.update(1) | ||
850 | 834 | ||
851 | logs = {"mode": "validation", "loss": loss} | 835 | logs = {"mode": "validation", "loss": loss} |
852 | local_progress_bar.set_postfix(**logs) | 836 | local_progress_bar.set_postfix(**logs) |
@@ -862,10 +846,10 @@ def main(): | |||
862 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 846 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
863 | min_val_loss = val_loss | 847 | min_val_loss = val_loss |
864 | 848 | ||
865 | checkpointer.save_samples( | 849 | if accelerator.is_main_process: |
866 | "validation", | 850 | checkpointer.save_samples( |
867 | global_step, | 851 | global_step, |
868 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 852 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
869 | 853 | ||
870 | accelerator.wait_for_everyone() | 854 | accelerator.wait_for_everyone() |
871 | 855 | ||
diff --git a/textual_inversion.py b/textual_inversion.py index 81f1cf5..399d876 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -581,7 +581,7 @@ def main(): | |||
581 | 581 | ||
582 | # TODO (patil-suraj): load scheduler using args | 582 | # TODO (patil-suraj): load scheduler using args |
583 | noise_scheduler = DDPMScheduler( | 583 | noise_scheduler = DDPMScheduler( |
584 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" | 584 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 |
585 | ) | 585 | ) |
586 | 586 | ||
587 | datamodule = CSVDataModule( | 587 | datamodule = CSVDataModule( |