 dreambooth.py | 48 ++++++++++++++++++++++++------------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 7c0d32c..170b8e9 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -2,6 +2,7 @@ import argparse
 import math
 import os
 import datetime
+import logging
 from pathlib import Path

 import numpy as np
@@ -20,7 +21,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 import json
-import os

 from data.dreambooth.csv import CSVDataModule
 from data.dreambooth.prompt import PromptDataset
@@ -95,7 +95,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1000,
+        default=600,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -380,8 +380,8 @@ class Checkpointer:

     @torch.no_grad()
     def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = f"{self.output_dir}/samples/{mode}"
-        os.makedirs(samples_path, exist_ok=True)
+        samples_path = Path(self.output_dir).joinpath("samples").joinpath(mode)
+        samples_path.mkdir(parents=True, exist_ok=True)

         unwrapped = self.accelerator.unwrap_model(self.unet)
         pipeline = StableDiffusionPipeline(
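
Aside: a minimal, self-contained sketch (not part of the patch; the directory names are made up) of the pathlib pattern this hunk switches to, replacing string formatting plus os.makedirs:

    from pathlib import Path

    # joinpath builds the nested path; mkdir(parents=True, exist_ok=True)
    # creates missing parents and is a no-op if the directory already exists.
    output_dir = Path("output/dreambooth-example")  # hypothetical base directory
    samples_path = output_dir.joinpath("samples").joinpath("training")
    samples_path.mkdir(parents=True, exist_ok=True)
    print(samples_path)  # output/dreambooth-example/samples/training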
@@ -471,7 +471,6 @@ class Checkpointer:
         del all_samples
         del image_grid

-        del checker
         del unwrapped
         del pipeline

@@ -483,8 +482,8 @@ def main():
     args = parse_args()

     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = f"{args.output_dir}/{slugify(args.identifier)}/{now}"
-    os.makedirs(basepath, exist_ok=True)
+    basepath = Path(args.output_dir).joinpath(slugify(args.identifier)).joinpath(now)
+    basepath.mkdir(parents=True, exist_ok=True)

     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
@@ -493,6 +492,8 @@ def main():
         mixed_precision=args.mixed_precision
     )

+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
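
Aside: a minimal sketch (not part of the patch; the run directory is hypothetical) of what the new logging.basicConfig call does: it routes the root logger to a file inside the run directory at DEBUG level:

    import logging
    from pathlib import Path

    basepath = Path("output/example-run")  # hypothetical run directory
    basepath.mkdir(parents=True, exist_ok=True)

    # All subsequent logging (including library loggers that propagate to the
    # root logger) is appended to output/example-run/log.txt.
    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
    logging.info("training started")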
@@ -655,8 +656,7 @@ def main():

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(
-        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -681,13 +681,12 @@ def main():
     vae.eval()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(
-        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-    # Afterwards we recalculate our number of training epochs
-    args.num_train_epochs = math.ceil(
-        args.max_train_steps / num_update_steps_per_epoch)
+
+    num_val_steps_per_epoch = len(val_dataloader)
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
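
Aside: a small worked example (sizes invented) of the step math in the two hunks above. Only the training dataloader now counts toward optimizer updates, and the epoch count is derived from max_train_steps:

    import math

    len_train_dataloader = 300         # hypothetical number of training batches
    gradient_accumulation_steps = 2
    max_train_steps = 600              # the new default from this patch

    # One optimizer update per `gradient_accumulation_steps` training batches.
    num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 150
    num_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)                        # 4
    print(num_update_steps_per_epoch, num_epochs)  # 150 4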
@@ -698,7 +697,7 @@ def main():
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

     logger.info("***** Running training *****")
-    logger.info(f" Num Epochs = {args.num_train_epochs}")
+    logger.info(f" Num Epochs = {num_epochs}")
     logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
@@ -713,14 +712,16 @@ def main():
             0,
             args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

-    local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
-    local_progress_bar.set_description("Steps ")
+    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+                              disable=not accelerator.is_local_main_process)
+    local_progress_bar.set_description("Batch X out of Y")

-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Global steps")
+    global_progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar.set_description("Total progress")

     try:
-        for epoch in range(args.num_train_epochs):
+        for epoch in range(num_epochs):
+            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
             local_progress_bar.reset()

             unet.train()
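
Aside: a runnable sketch (loop sizes invented, no actual training) of the two-bar layout this hunk sets up: the local bar is relabelled and reset every epoch and also counts validation batches, while the global bar only advances on training updates:

    from tqdm.auto import tqdm

    num_update_steps_per_epoch = 150   # hypothetical training updates per epoch
    num_val_steps_per_epoch = 20       # hypothetical validation batches per epoch
    max_train_steps = 600
    num_epochs = 4

    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch))
    global_progress_bar = tqdm(range(max_train_steps))
    global_progress_bar.set_description("Total progress")

    for epoch in range(num_epochs):
        local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
        local_progress_bar.reset()
        for _ in range(num_update_steps_per_epoch):    # training updates
            local_progress_bar.update(1)
            global_progress_bar.update(1)
        for _ in range(num_val_steps_per_epoch):       # validation: local bar only
            local_progress_bar.update(1)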
@@ -784,13 +785,13 @@ def main():
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
-                    progress_bar.update(1)
+                    global_progress_bar.update(1)

                     global_step += 1

                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                         local_progress_bar.clear()
-                        progress_bar.clear()
+                        global_progress_bar.clear()

                         checkpointer.save_samples(
                             "training",
@@ -846,7 +847,6 @@ def main():

                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
-                    progress_bar.update(1)

                 logs = {"mode": "validation", "loss": loss}
                 local_progress_bar.set_postfix(**logs)
@@ -856,7 +856,7 @@ def main():
             accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)

             local_progress_bar.clear()
-            progress_bar.clear()
+            global_progress_bar.clear()

             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")