summary | refs | log | tree | commit | diff | stats
path: root/dreambooth.py
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2022-09-28 15:05:30 +0200
committer Volpeon <git@volpeon.ink> 2022-09-28 15:05:30 +0200
commit 5b54788842cdd7b342bd60d6944158009130b4d4 (patch)
tree ebeb9fbb2c2e4d4d2e406d3ad86b01deb2b1525a /dreambooth.py
parent Infer script: Store args, better path/file names (diff)
download textual-inversion-diff-5b54788842cdd7b342bd60d6944158009130b4d4.tar.gz
textual-inversion-diff-5b54788842cdd7b342bd60d6944158009130b4d4.tar.bz2
textual-inversion-diff-5b54788842cdd7b342bd60d6944158009130b4d4.zip
Improved sample output and progress bars
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- dreambooth.py 104
1 files changed, 44 insertions, 60 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 170b8e9..2df6858 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
112 parser.add_argument( 112 parser.add_argument(
113 "--learning_rate", 113 "--learning_rate",
114 type=float, 114 type=float,
115 default=5e-6, 115 default=3e-6,
116 help="Initial learning rate (after the potential warmup period) to use.", 116 help="Initial learning rate (after the potential warmup period) to use.",
117 ) 117 )
118 parser.add_argument( 118 parser.add_argument(
@@ -183,12 +183,6 @@ def parse_args():
183 help="For distributed training: local_rank" 183 help="For distributed training: local_rank"
184 ) 184 )
185 parser.add_argument( 185 parser.add_argument(
186 "--checkpoint_frequency",
187 type=int,
188 default=200,
189 help="How often to save a checkpoint and sample image",
190 )
191 parser.add_argument(
192 "--sample_image_size", 186 "--sample_image_size",
193 type=int, 187 type=int,
194 default=512, 188 default=512,
@@ -379,8 +373,8 @@ class Checkpointer:
379 torch.cuda.empty_cache() 373 torch.cuda.empty_cache()
380 374
381 @torch.no_grad() 375 @torch.no_grad()
382 def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps): 376 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
383 samples_path = Path(self.output_dir).joinpath("samples").joinpath(mode) 377 samples_path = Path(self.output_dir).joinpath("samples")
384 samples_path.mkdir(parents=True, exist_ok=True) 378 samples_path.mkdir(parents=True, exist_ok=True)
385 379
386 unwrapped = self.accelerator.unwrap_model(self.unet) 380 unwrapped = self.accelerator.unwrap_model(self.unet)
@@ -397,12 +391,10 @@ class Checkpointer:
397 ).to(self.accelerator.device) 391 ).to(self.accelerator.device)
398 pipeline.enable_attention_slicing() 392 pipeline.enable_attention_slicing()
399 393
400 data = { 394 train_data = self.datamodule.train_dataloader()
401 "training": self.datamodule.train_dataloader(), 395 val_data = self.datamodule.val_dataloader()
402 "validation": self.datamodule.val_dataloader(),
403 }[mode]
404 396
405 if mode == "validation" and self.stable_sample_batches > 0 and step > 0: 397 if self.stable_sample_batches > 0:
406 stable_latents = torch.randn( 398 stable_latents = torch.randn(
407 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), 399 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
408 device=pipeline.device, 400 device=pipeline.device,
@@ -410,14 +402,14 @@ class Checkpointer:
410 ) 402 )
411 403
412 all_samples = [] 404 all_samples = []
413 filename = f"stable_step_%d.png" % (step) 405 filename = f"step_{step}_val_stable.png"
414 406
415 data_enum = enumerate(data) 407 data_enum = enumerate(val_data)
416 408
417 # Generate and save stable samples 409 # Generate and save stable samples
418 for i in range(0, self.stable_sample_batches): 410 for i in range(0, self.stable_sample_batches):
419 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 411 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
420 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] 412 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
421 413
422 with self.accelerator.autocast(): 414 with self.accelerator.autocast():
423 samples = pipeline( 415 samples = pipeline(
@@ -441,35 +433,35 @@ class Checkpointer:
441 del image_grid 433 del image_grid
442 del stable_latents 434 del stable_latents
443 435
444 all_samples = [] 436 for data, pool in [(train_data, "train"), (val_data, "val")]:
445 filename = f"step_%d.png" % (step) 437 all_samples = []
438 filename = f"step_{step}_{pool}.png"
446 439
447 data_enum = enumerate(data) 440 data_enum = enumerate(data)
448 441
449 # Generate and save random samples 442 for i in range(0, self.random_sample_batches):
450 for i in range(0, self.random_sample_batches): 443 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
451 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 444 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
452 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
453 445
454 with self.accelerator.autocast(): 446 with self.accelerator.autocast():
455 samples = pipeline( 447 samples = pipeline(
456 prompt=prompt, 448 prompt=prompt,
457 height=self.sample_image_size, 449 height=self.sample_image_size,
458 width=self.sample_image_size, 450 width=self.sample_image_size,
459 guidance_scale=guidance_scale, 451 guidance_scale=guidance_scale,
460 eta=eta, 452 eta=eta,
461 num_inference_steps=num_inference_steps, 453 num_inference_steps=num_inference_steps,
462 output_type='pil' 454 output_type='pil'
463 )["sample"] 455 )["sample"]
464 456
465 all_samples += samples 457 all_samples += samples
466 del samples 458 del samples
467 459
468 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) 460 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
469 image_grid.save(f"{samples_path}/{filename}") 461 image_grid.save(f"{samples_path}/{filename}")
470 462
471 del all_samples 463 del all_samples
472 del image_grid 464 del image_grid
473 465
474 del unwrapped 466 del unwrapped
475 del pipeline 467 del pipeline
@@ -594,8 +586,7 @@ def main():
594 beta_start=0.00085, 586 beta_start=0.00085,
595 beta_end=0.012, 587 beta_end=0.012,
596 beta_schedule="scaled_linear", 588 beta_schedule="scaled_linear",
597 num_train_timesteps=1000, 589 num_train_timesteps=1000
598 tensor_format="pt"
599 ) 590 )
600 591
601 def collate_fn(examples): 592 def collate_fn(examples):
@@ -687,6 +678,7 @@ def main():
687 678
688 num_val_steps_per_epoch = len(val_dataloader) 679 num_val_steps_per_epoch = len(val_dataloader)
689 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 680 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
681 val_steps = num_val_steps_per_epoch * num_epochs
690 682
691 # We need to initialize the trackers we use, and also store our configuration. 683 # We need to initialize the trackers we use, and also store our configuration.
692 # The trackers initializes automatically on the main process. 684 # The trackers initializes automatically on the main process.
@@ -707,16 +699,16 @@ def main():
707 global_step = 0 699 global_step = 0
708 min_val_loss = np.inf 700 min_val_loss = np.inf
709 701
710 checkpointer.save_samples( 702 if accelerator.is_main_process:
711 "validation", 703 checkpointer.save_samples(
712 0, 704 0,
713 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 705 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
714 706
715 local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), 707 local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
716 disable=not accelerator.is_local_main_process) 708 disable=not accelerator.is_local_main_process)
717 local_progress_bar.set_description("Batch X out of Y") 709 local_progress_bar.set_description("Batch X out of Y")
718 710
719 global_progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 711 global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
720 global_progress_bar.set_description("Total progress") 712 global_progress_bar.set_description("Total progress")
721 713
722 try: 714 try:
@@ -789,15 +781,6 @@ def main():
789 781
790 global_step += 1 782 global_step += 1
791 783
792 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
793 local_progress_bar.clear()
794 global_progress_bar.clear()
795
796 checkpointer.save_samples(
797 "training",
798 global_step,
799 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
800
801 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} 784 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
802 local_progress_bar.set_postfix(**logs) 785 local_progress_bar.set_postfix(**logs)
803 786
@@ -847,6 +830,7 @@ def main():
847 830
848 if accelerator.sync_gradients: 831 if accelerator.sync_gradients:
849 local_progress_bar.update(1) 832 local_progress_bar.update(1)
833 global_progress_bar.update(1)
850 834
851 logs = {"mode": "validation", "loss": loss} 835 logs = {"mode": "validation", "loss": loss}
852 local_progress_bar.set_postfix(**logs) 836 local_progress_bar.set_postfix(**logs)
@@ -862,10 +846,10 @@ def main():
862 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") 846 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
863 min_val_loss = val_loss 847 min_val_loss = val_loss
864 848
865 checkpointer.save_samples( 849 if accelerator.is_main_process:
866 "validation", 850 checkpointer.save_samples(
867 global_step, 851 global_step,
868 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 852 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
869 853
870 accelerator.wait_for_everyone() 854 accelerator.wait_for_everyone()
871 855