summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py48
1 files changed, 24 insertions, 24 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 7c0d32c..170b8e9 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -2,6 +2,7 @@ import argparse
2import math 2import math
3import os 3import os
4import datetime 4import datetime
5import logging
5from pathlib import Path 6from pathlib import Path
6 7
7import numpy as np 8import numpy as np
@@ -20,7 +21,6 @@ from tqdm.auto import tqdm
20from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 21from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
21from slugify import slugify 22from slugify import slugify
22import json 23import json
23import os
24 24
25from data.dreambooth.csv import CSVDataModule 25from data.dreambooth.csv import CSVDataModule
26from data.dreambooth.prompt import PromptDataset 26from data.dreambooth.prompt import PromptDataset
@@ -95,7 +95,7 @@ def parse_args():
95 parser.add_argument( 95 parser.add_argument(
96 "--max_train_steps", 96 "--max_train_steps",
97 type=int, 97 type=int,
98 default=1000, 98 default=600,
99 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 99 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
100 ) 100 )
101 parser.add_argument( 101 parser.add_argument(
@@ -380,8 +380,8 @@ class Checkpointer:
380 380
381 @torch.no_grad() 381 @torch.no_grad()
382 def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps): 382 def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
383 samples_path = f"{self.output_dir}/samples/{mode}" 383 samples_path = Path(self.output_dir).joinpath("samples").joinpath(mode)
384 os.makedirs(samples_path, exist_ok=True) 384 samples_path.mkdir(parents=True, exist_ok=True)
385 385
386 unwrapped = self.accelerator.unwrap_model(self.unet) 386 unwrapped = self.accelerator.unwrap_model(self.unet)
387 pipeline = StableDiffusionPipeline( 387 pipeline = StableDiffusionPipeline(
@@ -471,7 +471,6 @@ class Checkpointer:
471 del all_samples 471 del all_samples
472 del image_grid 472 del image_grid
473 473
474 del checker
475 del unwrapped 474 del unwrapped
476 del pipeline 475 del pipeline
477 476
@@ -483,8 +482,8 @@ def main():
483 args = parse_args() 482 args = parse_args()
484 483
485 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 484 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
486 basepath = f"{args.output_dir}/{slugify(args.identifier)}/{now}" 485 basepath = Path(args.output_dir).joinpath(slugify(args.identifier)).joinpath(now)
487 os.makedirs(basepath, exist_ok=True) 486 basepath.mkdir(parents=True, exist_ok=True)
488 487
489 accelerator = Accelerator( 488 accelerator = Accelerator(
490 log_with=LoggerType.TENSORBOARD, 489 log_with=LoggerType.TENSORBOARD,
@@ -493,6 +492,8 @@ def main():
493 mixed_precision=args.mixed_precision 492 mixed_precision=args.mixed_precision
494 ) 493 )
495 494
495 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
496
496 # If passed along, set the training seed now. 497 # If passed along, set the training seed now.
497 if args.seed is not None: 498 if args.seed is not None:
498 set_seed(args.seed) 499 set_seed(args.seed)
@@ -655,8 +656,7 @@ def main():
655 656
656 # Scheduler and math around the number of training steps. 657 # Scheduler and math around the number of training steps.
657 overrode_max_train_steps = False 658 overrode_max_train_steps = False
658 num_update_steps_per_epoch = math.ceil( 659 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
659 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
660 if args.max_train_steps is None: 660 if args.max_train_steps is None:
661 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 661 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
662 overrode_max_train_steps = True 662 overrode_max_train_steps = True
@@ -681,13 +681,12 @@ def main():
681 vae.eval() 681 vae.eval()
682 682
683 # We need to recalculate our total training steps as the size of the training dataloader may have changed. 683 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
684 num_update_steps_per_epoch = math.ceil( 684 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
685 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
686 if overrode_max_train_steps: 685 if overrode_max_train_steps:
687 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 686 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
688 # Afterwards we recalculate our number of training epochs 687
689 args.num_train_epochs = math.ceil( 688 num_val_steps_per_epoch = len(val_dataloader)
690 args.max_train_steps / num_update_steps_per_epoch) 689 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
691 690
692 # We need to initialize the trackers we use, and also store our configuration. 691 # We need to initialize the trackers we use, and also store our configuration.
693 # The trackers initializes automatically on the main process. 692 # The trackers initializes automatically on the main process.
@@ -698,7 +697,7 @@ def main():
698 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 697 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
699 698
700 logger.info("***** Running training *****") 699 logger.info("***** Running training *****")
701 logger.info(f" Num Epochs = {args.num_train_epochs}") 700 logger.info(f" Num Epochs = {num_epochs}")
702 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 701 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
703 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 702 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
704 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 703 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
@@ -713,14 +712,16 @@ def main():
713 0, 712 0,
714 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 713 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
715 714
716 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) 715 local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
717 local_progress_bar.set_description("Steps ") 716 disable=not accelerator.is_local_main_process)
717 local_progress_bar.set_description("Batch X out of Y")
718 718
719 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 719 global_progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
720 progress_bar.set_description("Global steps") 720 global_progress_bar.set_description("Total progress")
721 721
722 try: 722 try:
723 for epoch in range(args.num_train_epochs): 723 for epoch in range(num_epochs):
724 local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
724 local_progress_bar.reset() 725 local_progress_bar.reset()
725 726
726 unet.train() 727 unet.train()
@@ -784,13 +785,13 @@ def main():
784 # Checks if the accelerator has performed an optimization step behind the scenes 785 # Checks if the accelerator has performed an optimization step behind the scenes
785 if accelerator.sync_gradients: 786 if accelerator.sync_gradients:
786 local_progress_bar.update(1) 787 local_progress_bar.update(1)
787 progress_bar.update(1) 788 global_progress_bar.update(1)
788 789
789 global_step += 1 790 global_step += 1
790 791
791 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: 792 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
792 local_progress_bar.clear() 793 local_progress_bar.clear()
793 progress_bar.clear() 794 global_progress_bar.clear()
794 795
795 checkpointer.save_samples( 796 checkpointer.save_samples(
796 "training", 797 "training",
@@ -846,7 +847,6 @@ def main():
846 847
847 if accelerator.sync_gradients: 848 if accelerator.sync_gradients:
848 local_progress_bar.update(1) 849 local_progress_bar.update(1)
849 progress_bar.update(1)
850 850
851 logs = {"mode": "validation", "loss": loss} 851 logs = {"mode": "validation", "loss": loss}
852 local_progress_bar.set_postfix(**logs) 852 local_progress_bar.set_postfix(**logs)
@@ -856,7 +856,7 @@ def main():
856 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) 856 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
857 857
858 local_progress_bar.clear() 858 local_progress_bar.clear()
859 progress_bar.clear() 859 global_progress_bar.clear()
860 860
861 if min_val_loss > val_loss: 861 if min_val_loss > val_loss:
862 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") 862 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")