diff options
| -rw-r--r-- | dreambooth.py | 48 |
1 file changed, 24 insertions, 24 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 7c0d32c..170b8e9 100644
--- a/dreambooth.py
+++ b/dreambooth.py
| @@ -2,6 +2,7 @@ import argparse | |||
| 2 | import math | 2 | import math |
| 3 | import os | 3 | import os |
| 4 | import datetime | 4 | import datetime |
| 5 | import logging | ||
| 5 | from pathlib import Path | 6 | from pathlib import Path |
| 6 | 7 | ||
| 7 | import numpy as np | 8 | import numpy as np |
| @@ -20,7 +21,6 @@ from tqdm.auto import tqdm | |||
| 20 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | 21 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
| 21 | from slugify import slugify | 22 | from slugify import slugify |
| 22 | import json | 23 | import json |
| 23 | import os | ||
| 24 | 24 | ||
| 25 | from data.dreambooth.csv import CSVDataModule | 25 | from data.dreambooth.csv import CSVDataModule |
| 26 | from data.dreambooth.prompt import PromptDataset | 26 | from data.dreambooth.prompt import PromptDataset |
| @@ -95,7 +95,7 @@ def parse_args(): | |||
| 95 | parser.add_argument( | 95 | parser.add_argument( |
| 96 | "--max_train_steps", | 96 | "--max_train_steps", |
| 97 | type=int, | 97 | type=int, |
| 98 | default=1000, | 98 | default=600, |
| 99 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 99 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
| 100 | ) | 100 | ) |
| 101 | parser.add_argument( | 101 | parser.add_argument( |
| @@ -380,8 +380,8 @@ class Checkpointer: | |||
| 380 | 380 | ||
| 381 | @torch.no_grad() | 381 | @torch.no_grad() |
| 382 | def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps): | 382 | def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps): |
| 383 | samples_path = f"{self.output_dir}/samples/{mode}" | 383 | samples_path = Path(self.output_dir).joinpath("samples").joinpath(mode) |
| 384 | os.makedirs(samples_path, exist_ok=True) | 384 | samples_path.mkdir(parents=True, exist_ok=True) |
| 385 | 385 | ||
| 386 | unwrapped = self.accelerator.unwrap_model(self.unet) | 386 | unwrapped = self.accelerator.unwrap_model(self.unet) |
| 387 | pipeline = StableDiffusionPipeline( | 387 | pipeline = StableDiffusionPipeline( |
| @@ -471,7 +471,6 @@ class Checkpointer: | |||
| 471 | del all_samples | 471 | del all_samples |
| 472 | del image_grid | 472 | del image_grid |
| 473 | 473 | ||
| 474 | del checker | ||
| 475 | del unwrapped | 474 | del unwrapped |
| 476 | del pipeline | 475 | del pipeline |
| 477 | 476 | ||
| @@ -483,8 +482,8 @@ def main(): | |||
| 483 | args = parse_args() | 482 | args = parse_args() |
| 484 | 483 | ||
| 485 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 484 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 486 | basepath = f"{args.output_dir}/{slugify(args.identifier)}/{now}" | 485 | basepath = Path(args.output_dir).joinpath(slugify(args.identifier)).joinpath(now) |
| 487 | os.makedirs(basepath, exist_ok=True) | 486 | basepath.mkdir(parents=True, exist_ok=True) |
| 488 | 487 | ||
| 489 | accelerator = Accelerator( | 488 | accelerator = Accelerator( |
| 490 | log_with=LoggerType.TENSORBOARD, | 489 | log_with=LoggerType.TENSORBOARD, |
| @@ -493,6 +492,8 @@ def main(): | |||
| 493 | mixed_precision=args.mixed_precision | 492 | mixed_precision=args.mixed_precision |
| 494 | ) | 493 | ) |
| 495 | 494 | ||
| 495 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | ||
| 496 | |||
| 496 | # If passed along, set the training seed now. | 497 | # If passed along, set the training seed now. |
| 497 | if args.seed is not None: | 498 | if args.seed is not None: |
| 498 | set_seed(args.seed) | 499 | set_seed(args.seed) |
| @@ -655,8 +656,7 @@ def main(): | |||
| 655 | 656 | ||
| 656 | # Scheduler and math around the number of training steps. | 657 | # Scheduler and math around the number of training steps. |
| 657 | overrode_max_train_steps = False | 658 | overrode_max_train_steps = False |
| 658 | num_update_steps_per_epoch = math.ceil( | 659 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
| 659 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) | ||
| 660 | if args.max_train_steps is None: | 660 | if args.max_train_steps is None: |
| 661 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 661 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
| 662 | overrode_max_train_steps = True | 662 | overrode_max_train_steps = True |
| @@ -681,13 +681,12 @@ def main(): | |||
| 681 | vae.eval() | 681 | vae.eval() |
| 682 | 682 | ||
| 683 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | 683 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. |
| 684 | num_update_steps_per_epoch = math.ceil( | 684 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
| 685 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) | ||
| 686 | if overrode_max_train_steps: | 685 | if overrode_max_train_steps: |
| 687 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 686 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
| 688 | # Afterwards we recalculate our number of training epochs | 687 | |
| 689 | args.num_train_epochs = math.ceil( | 688 | num_val_steps_per_epoch = len(val_dataloader) |
| 690 | args.max_train_steps / num_update_steps_per_epoch) | 689 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
| 691 | 690 | ||
| 692 | # We need to initialize the trackers we use, and also store our configuration. | 691 | # We need to initialize the trackers we use, and also store our configuration. |
| 693 | # The trackers initializes automatically on the main process. | 692 | # The trackers initializes automatically on the main process. |
| @@ -698,7 +697,7 @@ def main(): | |||
| 698 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | 697 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps |
| 699 | 698 | ||
| 700 | logger.info("***** Running training *****") | 699 | logger.info("***** Running training *****") |
| 701 | logger.info(f" Num Epochs = {args.num_train_epochs}") | 700 | logger.info(f" Num Epochs = {num_epochs}") |
| 702 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") | 701 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") |
| 703 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") | 702 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") |
| 704 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | 703 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") |
| @@ -713,14 +712,16 @@ def main(): | |||
| 713 | 0, | 712 | 0, |
| 714 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 713 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
| 715 | 714 | ||
| 716 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) | 715 | local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
| 717 | local_progress_bar.set_description("Steps ") | 716 | disable=not accelerator.is_local_main_process) |
| 717 | local_progress_bar.set_description("Batch X out of Y") | ||
| 718 | 718 | ||
| 719 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | 719 | global_progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) |
| 720 | progress_bar.set_description("Global steps") | 720 | global_progress_bar.set_description("Total progress") |
| 721 | 721 | ||
| 722 | try: | 722 | try: |
| 723 | for epoch in range(args.num_train_epochs): | 723 | for epoch in range(num_epochs): |
| 724 | local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}") | ||
| 724 | local_progress_bar.reset() | 725 | local_progress_bar.reset() |
| 725 | 726 | ||
| 726 | unet.train() | 727 | unet.train() |
| @@ -784,13 +785,13 @@ def main(): | |||
| 784 | # Checks if the accelerator has performed an optimization step behind the scenes | 785 | # Checks if the accelerator has performed an optimization step behind the scenes |
| 785 | if accelerator.sync_gradients: | 786 | if accelerator.sync_gradients: |
| 786 | local_progress_bar.update(1) | 787 | local_progress_bar.update(1) |
| 787 | progress_bar.update(1) | 788 | global_progress_bar.update(1) |
| 788 | 789 | ||
| 789 | global_step += 1 | 790 | global_step += 1 |
| 790 | 791 | ||
| 791 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | 792 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: |
| 792 | local_progress_bar.clear() | 793 | local_progress_bar.clear() |
| 793 | progress_bar.clear() | 794 | global_progress_bar.clear() |
| 794 | 795 | ||
| 795 | checkpointer.save_samples( | 796 | checkpointer.save_samples( |
| 796 | "training", | 797 | "training", |
| @@ -846,7 +847,6 @@ def main(): | |||
| 846 | 847 | ||
| 847 | if accelerator.sync_gradients: | 848 | if accelerator.sync_gradients: |
| 848 | local_progress_bar.update(1) | 849 | local_progress_bar.update(1) |
| 849 | progress_bar.update(1) | ||
| 850 | 850 | ||
| 851 | logs = {"mode": "validation", "loss": loss} | 851 | logs = {"mode": "validation", "loss": loss} |
| 852 | local_progress_bar.set_postfix(**logs) | 852 | local_progress_bar.set_postfix(**logs) |
| @@ -856,7 +856,7 @@ def main(): | |||
| 856 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) | 856 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) |
| 857 | 857 | ||
| 858 | local_progress_bar.clear() | 858 | local_progress_bar.clear() |
| 859 | progress_bar.clear() | 859 | global_progress_bar.clear() |
| 860 | 860 | ||
| 861 | if min_val_loss > val_loss: | 861 | if min_val_loss > val_loss: |
| 862 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 862 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
