-rw-r--r--  data/dreambooth/csv.py |  10
-rw-r--r--  dreambooth.py          | 104
-rw-r--r--  textual_inversion.py   |   2
3 files changed, 50 insertions(+), 66 deletions(-)
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 14c13bb..85ed4a5 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -108,14 +108,14 @@ class CSVDataset(Dataset):
         else:
             self.class_data_root = None

-        self.interpolation = {"linear": PIL.Image.LINEAR,
-                              "bilinear": PIL.Image.BILINEAR,
-                              "bicubic": PIL.Image.BICUBIC,
-                              "lanczos": PIL.Image.LANCZOS,
+        self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
+                              "bilinear": transforms.InterpolationMode.BILINEAR,
+                              "bicubic": transforms.InterpolationMode.BICUBIC,
+                              "lanczos": transforms.InterpolationMode.LANCZOS,
                               }[interpolation]
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.Resize(size, interpolation=self.interpolation),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
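
Context for the hunk above: torchvision deprecated passing PIL.Image.* constants to its transforms in favor of the transforms.InterpolationMode enum, and Resize now honors the configured mode instead of a hardcoded bilinear. A minimal sketch of the resulting pattern ("bicubic" and the 512-pixel size are illustrative, not from the diff):

    from torchvision import transforms

    # Map the user-facing string onto torchvision's enum. This mirrors the
    # hunk, including its (perhaps surprising) NEAREST for the "linear" key.
    interpolation = {
        "linear": transforms.InterpolationMode.NEAREST,
        "bilinear": transforms.InterpolationMode.BILINEAR,
        "bicubic": transforms.InterpolationMode.BICUBIC,
        "lanczos": transforms.InterpolationMode.LANCZOS,
    }["bicubic"]

    # The configured mode is threaded through instead of being overridden.
    resize = transforms.Resize(512, interpolation=interpolation)
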
diff --git a/dreambooth.py b/dreambooth.py
index 170b8e9..2df6858 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-6,
+        default=3e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -183,12 +183,6 @@ def parse_args():
183 help="For distributed training: local_rank" 183 help="For distributed training: local_rank"
184 ) 184 )
185 parser.add_argument( 185 parser.add_argument(
186 "--checkpoint_frequency",
187 type=int,
188 default=200,
189 help="How often to save a checkpoint and sample image",
190 )
191 parser.add_argument(
192 "--sample_image_size", 186 "--sample_image_size",
193 type=int, 187 type=int,
194 default=512, 188 default=512,
@@ -379,8 +373,8 @@ class Checkpointer:
         torch.cuda.empty_cache()

     @torch.no_grad()
-    def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = Path(self.output_dir).joinpath("samples").joinpath(mode)
+    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
+        samples_path = Path(self.output_dir).joinpath("samples")
         samples_path.mkdir(parents=True, exist_ok=True)

         unwrapped = self.accelerator.unwrap_model(self.unet)
@@ -397,12 +391,10 @@ class Checkpointer:
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()

-        data = {
-            "training": self.datamodule.train_dataloader(),
-            "validation": self.datamodule.val_dataloader(),
-        }[mode]
+        train_data = self.datamodule.train_dataloader()
+        val_data = self.datamodule.val_dataloader()

-        if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
+        if self.stable_sample_batches > 0:
             stable_latents = torch.randn(
                 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
                 device=pipeline.device,
@@ -410,14 +402,14 @@ class Checkpointer:
             )

             all_samples = []
-            filename = f"stable_step_%d.png" % (step)
+            filename = f"step_{step}_val_stable.png"

-            data_enum = enumerate(data)
+            data_enum = enumerate(val_data)

             # Generate and save stable samples
             for i in range(0, self.stable_sample_batches):
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]

                 with self.accelerator.autocast():
                     samples = pipeline(
@@ -441,35 +433,35 @@ class Checkpointer:
             del image_grid
             del stable_latents

-        all_samples = []
-        filename = f"step_%d.png" % (step)
+        for data, pool in [(train_data, "train"), (val_data, "val")]:
+            all_samples = []
+            filename = f"step_{step}_{pool}.png"

-        data_enum = enumerate(data)
+            data_enum = enumerate(data)

-        # Generate and save random samples
-        for i in range(0, self.random_sample_batches):
-            prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(0, self.random_sample_batches):
+                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]

-            with self.accelerator.autocast():
-                samples = pipeline(
-                    prompt=prompt,
-                    height=self.sample_image_size,
-                    width=self.sample_image_size,
-                    guidance_scale=guidance_scale,
-                    eta=eta,
-                    num_inference_steps=num_inference_steps,
-                    output_type='pil'
-                )["sample"]
+                with self.accelerator.autocast():
+                    samples = pipeline(
+                        prompt=prompt,
+                        height=self.sample_image_size,
+                        width=self.sample_image_size,
+                        guidance_scale=guidance_scale,
+                        eta=eta,
+                        num_inference_steps=num_inference_steps,
+                        output_type='pil'
+                    )["sample"]

-            all_samples += samples
-            del samples
+                all_samples += samples
+                del samples

-        image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
-        image_grid.save(f"{samples_path}/{filename}")
+            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid.save(f"{samples_path}/{filename}")

-        del all_samples
-        del image_grid
+            del all_samples
+            del image_grid

         del unwrapped
         del pipeline
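
The three hunks above fold the old mode-specific sampling into a single pass: the seeded "stable" grid is rendered first (validation prompts only), then random-latent grids for both prompt pools. A condensed sketch of the new control flow, where render() is a hypothetical stand-in for the shared pipeline-call / make_grid / save body:

    def save_samples_outline(datamodule, render, stable_sample_batches, stable_latents, step):
        # Sketch only: names follow the diff; the loop bodies are elided into render().
        train_data = datamodule.train_dataloader()
        val_data = datamodule.val_dataloader()

        if stable_sample_batches > 0:
            # Latents drawn once from a fixed seed, so this grid stays
            # comparable across checkpoints; prompts come from validation.
            render(val_data, latents=stable_latents, out=f"step_{step}_val_stable.png")

        for data, pool in [(train_data, "train"), (val_data, "val")]:
            # Fresh random latents for both prompt pools on every call.
            render(data, latents=None, out=f"step_{step}_{pool}.png")
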
@@ -594,8 +586,7 @@ def main():
         beta_start=0.00085,
         beta_end=0.012,
         beta_schedule="scaled_linear",
-        num_train_timesteps=1000,
-        tensor_format="pt"
+        num_train_timesteps=1000
     )

     def collate_fn(examples):
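
Dropping tensor_format="pt" tracks upstream diffusers: early scheduler constructors took that keyword to toggle between numpy and torch tensors, and later releases removed it once the schedulers became torch-only. Assuming such a release, the surviving construction is simply:

    from diffusers import DDPMScheduler

    # torch-only schedulers: no tensor_format switch anymore.
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )
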
@@ -687,6 +678,7 @@ def main():

     num_val_steps_per_epoch = len(val_dataloader)
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    val_steps = num_val_steps_per_epoch * num_epochs

     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
@@ -707,16 +699,16 @@ def main():
     global_step = 0
     min_val_loss = np.inf

-    checkpointer.save_samples(
-        "validation",
-        0,
-        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+    if accelerator.is_main_process:
+        checkpointer.save_samples(
+            0,
+            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

     local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
                               disable=not accelerator.is_local_main_process)
     local_progress_bar.set_description("Batch X out of Y")

-    global_progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
     global_progress_bar.set_description("Total progress")

     try:
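
The new val_steps term lets the "Total progress" bar actually finish: it previously counted only optimizer steps, while the loop (after the update added further down) also ticks once per validation batch. The bookkeeping, with illustrative numbers rather than values from the diff:

    import math

    # Illustrative values only.
    max_train_steps = 1000
    num_update_steps_per_epoch = 250
    num_val_steps_per_epoch = 50

    num_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)  # 4
    val_steps = num_val_steps_per_epoch * num_epochs                      # 200

    # New tqdm total: one tick per optimizer step plus one per validation batch.
    total_ticks = max_train_steps + val_steps                             # 1200
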
@@ -789,15 +781,6 @@ def main():

                     global_step += 1

-                    if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                        local_progress_bar.clear()
-                        global_progress_bar.clear()
-
-                        checkpointer.save_samples(
-                            "training",
-                            global_step,
-                            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
-
                     logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
                     local_progress_bar.set_postfix(**logs)

@@ -847,6 +830,7 @@ def main():

                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
+                    global_progress_bar.update(1)

                 logs = {"mode": "validation", "loss": loss}
                 local_progress_bar.set_postfix(**logs)
@@ -862,10 +846,10 @@ def main():
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 min_val_loss = val_loss

-            checkpointer.save_samples(
-                "validation",
-                global_step,
-                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+            if accelerator.is_main_process:
+                checkpointer.save_samples(
+                    global_step,
+                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

             accelerator.wait_for_everyone()

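Both surviving save_samples call sites are now guarded by accelerator.is_main_process. Under multi-process accelerate training this stops every non-zero rank from building its own copy of the sampling pipeline and writing identical grids, while the existing wait_for_everyone() keeps the ranks in step afterwards. A generic sketch of the pattern (not code from this repo):

    from accelerate import Accelerator

    accelerator = Accelerator()

    if accelerator.is_main_process:
        # Write-once work: build the sampling pipeline, render, save grids.
        print(f"rank {accelerator.process_index}: saving samples")

    # Every rank blocks here until the main process catches up.
    accelerator.wait_for_everyone()
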
diff --git a/textual_inversion.py b/textual_inversion.py
index 81f1cf5..399d876 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -581,7 +581,7 @@ def main():

     # TODO (patil-suraj): laod scheduler using args
     noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )

     datamodule = CSVDataModule(