From 72b36be1bd8f45408830646efe2c7309d7dbfe33 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 28 Sep 2022 11:49:34 +0200
Subject: Dreambooth script: Fixed step calculation, added file log

---
 dreambooth.py | 48 ++++++++++++++++++++++++------------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 7c0d32c..170b8e9 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -2,6 +2,7 @@ import argparse
 import math
 import os
 import datetime
+import logging
 from pathlib import Path
 
 import numpy as np
@@ -20,7 +21,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 import json
-import os
 
 from data.dreambooth.csv import CSVDataModule
 from data.dreambooth.prompt import PromptDataset
@@ -95,7 +95,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1000,
+        default=600,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -380,8 +380,8 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = f"{self.output_dir}/samples/{mode}"
-        os.makedirs(samples_path, exist_ok=True)
+        samples_path = Path(self.output_dir).joinpath("samples").joinpath(mode)
+        samples_path.mkdir(parents=True, exist_ok=True)
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
         pipeline = StableDiffusionPipeline(
@@ -471,7 +471,6 @@ class Checkpointer:
 
             del all_samples
            del image_grid
-            del checker
             del unwrapped
             del pipeline
 
@@ -483,8 +482,8 @@ def main():
     args = parse_args()
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = f"{args.output_dir}/{slugify(args.identifier)}/{now}"
-    os.makedirs(basepath, exist_ok=True)
+    basepath = Path(args.output_dir).joinpath(slugify(args.identifier)).joinpath(now)
+    basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
@@ -493,6 +492,8 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -655,8 +656,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(
-        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -681,13 +681,12 @@ def main():
     vae.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(
-        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-    # Afterwards we recalculate our number of training epochs
-    args.num_train_epochs = math.ceil(
-        args.max_train_steps / num_update_steps_per_epoch)
+
+    num_val_steps_per_epoch = len(val_dataloader)
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
@@ -698,7 +697,7 @@ def main():
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Num Epochs = {num_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
@@ -713,14 +712,16 @@ def main():
         0,
         args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
-    local_progress_bar.set_description("Steps ")
+    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+                              disable=not accelerator.is_local_main_process)
+    local_progress_bar.set_description("Batch X out of Y")
 
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Global steps")
+    global_progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar.set_description("Total progress")
 
     try:
-        for epoch in range(args.num_train_epochs):
+        for epoch in range(num_epochs):
+            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
             local_progress_bar.reset()
 
             unet.train()
@@ -784,13 +785,13 @@ def main():
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
-                    progress_bar.update(1)
+                    global_progress_bar.update(1)
 
                     global_step += 1
 
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                         local_progress_bar.clear()
-                        progress_bar.clear()
+                        global_progress_bar.clear()
 
                         checkpointer.save_samples(
                             "training",
@@ -846,7 +847,6 @@ def main():
 
                     if accelerator.sync_gradients:
                         local_progress_bar.update(1)
-                        progress_bar.update(1)
 
                     logs = {"mode": "validation", "loss": loss}
                     local_progress_bar.set_postfix(**logs)
@@ -856,7 +856,7 @@ def main():
             accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
             local_progress_bar.clear()
-            progress_bar.clear()
+            global_progress_bar.clear()
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
-- 
cgit v1.2.3-70-g09d2
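For reference, a minimal sketch of the step/epoch arithmetic this patch switches to: optimizer updates per epoch are now derived from the training dataloader alone, validation batches only widen the per-epoch progress bar, and the epoch count follows from max_train_steps. The batch counts below are hypothetical placeholders standing in for len(train_dataloader) and len(val_dataloader), not values taken from the repository.

import math

# Hypothetical dataloader sizes; substitute len(train_dataloader) / len(val_dataloader).
train_batches = 24
val_batches = 6
gradient_accumulation_steps = 1
max_train_steps = 600  # new default introduced by this patch

# Updates per epoch count training batches only; the old code also added
# validation batches, which inflated the schedule.
num_update_steps_per_epoch = math.ceil(train_batches / gradient_accumulation_steps)
num_val_steps_per_epoch = val_batches

# Epochs follow from the global step budget.
num_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

# The per-epoch bar spans train updates plus validation batches, while the
# global bar only advances on optimizer steps toward max_train_steps.
local_bar_total = num_update_steps_per_epoch + num_val_steps_per_epoch
global_bar_total = max_train_steps

print(num_epochs, local_bar_total, global_bar_total)  # -> 25 30 600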