From 6aadb34af4fe5ca2dfc92fae8eee87610a5848ad Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 8 Oct 2022 21:56:54 +0200 Subject: Update --- dreambooth.py | 75 +++++++++++++++++++++++++++++++++++------------------------ 1 file changed, 45 insertions(+), 30 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index a26bea7..7b61c45 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -3,6 +3,7 @@ import math import os import datetime import logging +import json from pathlib import Path import numpy as np @@ -21,7 +22,6 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -import json from data.csv import CSVDataModule @@ -68,7 +68,7 @@ def parse_args(): parser.add_argument( "--num_class_images", type=int, - default=4, + default=200, help="How many class images to generate per training image." ) parser.add_argument( @@ -140,7 +140,7 @@ def parse_args(): parser.add_argument( "--lr_scheduler", type=str, - default="constant", + default="linear", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' @@ -198,6 +198,12 @@ def parse_args(): default=-1, help="For distributed training: local_rank" ) + parser.add_argument( + "--sample_frequency", + type=int, + default=100, + help="How often to save a checkpoint and sample image", + ) parser.add_argument( "--sample_image_size", type=int, @@ -366,20 +372,20 @@ class Checkpointer: generator=generator, ) - for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: - all_samples = [] - file_path = samples_path.joinpath(pool, f"step_{step}.png") - file_path.parent.mkdir(parents=True, exist_ok=True) + with torch.inference_mode(): + for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: + all_samples = [] + file_path = samples_path.joinpath(pool, f"step_{step}.png") + file_path.parent.mkdir(parents=True, exist_ok=True) - data_enum = enumerate(data) + data_enum = enumerate(data) - for i in range(self.sample_batches): - batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] - prompt = [prompt.format(self.instance_identifier) - for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] - nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] + for i in range(self.sample_batches): + batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] + prompt = [prompt.format(self.instance_identifier) + for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] + nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] - with self.accelerator.autocast(): samples = pipeline( prompt=prompt, negative_prompt=nprompt, @@ -393,15 +399,15 @@ class Checkpointer: output_type='pil' )["sample"] - all_samples += samples + all_samples += samples - del samples + del samples - image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) - image_grid.save(file_path) + image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) + image_grid.save(file_path) - del all_samples - del image_grid + del all_samples + del image_grid del unwrapped del scheduler @@ -538,7 +544,7 @@ def main(): datamodule.setup() if args.num_class_images != 0: - missing_data = [item for item in datamodule.data if not item[1].exists()] + missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] if len(missing_data) != 0: batched_data = [missing_data[i:i+args.sample_batch_size] @@ -558,20 +564,20 @@ def main(): pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) - for batch in batched_data: - image_name = [p[1] for p in batch] - prompt = [p[2].format(args.class_identifier) for p in batch] - nprompt = [p[3] for p in batch] + with torch.inference_mode(): + for batch in batched_data: + image_name = [p.class_image_path for p in batch] + prompt = [p.prompt.format(args.class_identifier) for p in batch] + nprompt = [p.nprompt for p in batch] - with accelerator.autocast(): images = pipeline( prompt=prompt, negative_prompt=nprompt, num_inference_steps=args.sample_steps ).images - for i, image in enumerate(images): - image.save(image_name[i]) + for i, image in enumerate(images): + image.save(image_name[i]) del pipeline @@ -677,6 +683,8 @@ def main(): unet.train() train_loss = 0.0 + sample_checkpoint = False + for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space @@ -737,6 +745,9 @@ def main(): global_step += 1 + if global_step % args.sample_frequency == 0: + sample_checkpoint = True + logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} local_progress_bar.set_postfix(**logs) @@ -783,7 +794,11 @@ def main(): val_loss /= len(val_dataloader) - accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) + accelerator.log({ + "train/loss": train_loss, + "val/loss": val_loss, + "lr": lr_scheduler.get_last_lr()[0] + }, step=global_step) local_progress_bar.clear() global_progress_bar.clear() @@ -792,7 +807,7 @@ def main(): accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") min_val_loss = val_loss - if accelerator.is_main_process: + if sample_checkpoint and accelerator.is_main_process: checkpointer.save_samples( global_step, args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) -- cgit v1.2.3-54-g00ecf