From a8a5abae42f6f42056cc27e0cf5313aab080c3a7 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Sep 2022 22:16:13 +0200
Subject: Various improvements, added inference script

---
 .gitignore             |   1 +
 data/dreambooth/csv.py |   1 +
 dreambooth.py          |  98 ++++++++++++++++++++-------------------
 infer.py               | 121 +++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 175 insertions(+), 46 deletions(-)
 create mode 100644 infer.py

diff --git a/.gitignore b/.gitignore
index 91a5e07..218c628 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,5 +161,6 @@ cython_debug/
 
 text-inversion-model/
 dreambooth-model/
+inference/
 conf*.json
 v1-inference.yaml*
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index e70c068..14c13bb 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -117,6 +117,7 @@ class CSVDataset(Dataset):
             [
                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
diff --git a/dreambooth.py b/dreambooth.py
index c01cbe3..bc7a472 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import datetime
@@ -61,7 +60,8 @@ def parse_args():
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -72,7 +72,8 @@
         "--seed",
         type=int,
         default=None,
-        help="A seed for reproducible training.")
+        help="A seed for reproducible training."
+    )
     parser.add_argument(
         "--resolution",
         type=int,
@@ -94,7 +95,7 @@
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=1000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -184,7 +185,7 @@
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
-        default=500,
+        default=200,
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
@@ -220,7 +221,7 @@
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=50,
+        default=80,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -381,7 +382,6 @@ class Checkpointer:
     def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = f"{self.output_dir}/samples/{mode}"
         os.makedirs(samples_path, exist_ok=True)
-        checker = NoCheck()
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
         pipeline = StableDiffusionPipeline(
@@ -507,6 +507,7 @@ def main():
             torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
             pipeline = StableDiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
+            pipeline.enable_attention_slicing()
             pipeline.set_progress_bar_config(disable=True)
 
             num_new_images = args.num_class_images - cur_class_images
@@ -589,7 +590,11 @@ def main():
 
     # TODO (patil-suraj): laod scheduler using args
    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
+        tensor_format="pt"
     )
 
     def collate_fn(examples):
@@ -709,7 +714,7 @@ def main():
             args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
-    local_progress_bar.set_description("Steps")
+    local_progress_bar.set_description("Steps ")
 
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Global steps")
@@ -723,31 +728,31 @@ def main():
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                with accelerator.autocast():
-                    # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn(latents.shape).to(latents.device)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-
-                    # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                # Convert images to latent space
+                with torch.no_grad():
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn(latents.shape).to(latents.device)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                          (bsz,), device=latents.device)
+                timesteps = timesteps.long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Get the text embedding for conditioning
+                with torch.no_grad():
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                # Predict the noise residual
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
+                with accelerator.autocast():
                     if args.with_prior_preservation:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
@@ -766,12 +771,12 @@ def main():
                         loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
-                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
-
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
-                optimizer.zero_grad()
+                optimizer.zero_grad(set_to_none=True)
 
                 loss = loss.detach().item()
                 train_loss += loss
@@ -804,7 +809,7 @@ def main():
         val_loss = 0.0
 
         for step, batch in enumerate(val_dataloader):
-            with torch.no_grad(), accelerator.autocast():
+            with torch.no_grad():
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
 
@@ -822,18 +827,19 @@ def main():
 
                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                if args.with_prior_preservation:
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                with accelerator.autocast():
+                    if args.with_prior_preservation:
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
 
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                            reduction="none").mean([1, 2, 3]).mean()
+                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                                reduction="none").mean([1, 2, 3]).mean()
 
-                    loss = loss + args.prior_loss_weight * prior_loss
-                else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 loss = loss.detach().item()
                 val_loss += loss
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..b9e9ff7
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,121 @@
+import argparse
+import datetime
+from pathlib import Path
+from torch import autocast
+from diffusers import StableDiffusionPipeline
+import torch
+import json
+from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from slugify import slugify
+from pipelines.stable_diffusion.no_check import NoCheck
+
+model_id = "path-to-your-trained-model"
+
+prompt = "A photo of sks dog in a bucket"
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--batch_num",
+        type=int,
+        default=50,
+    )
+    parser.add_argument(
+        "--steps",
+        type=int,
+        default=80,
+    )
+    parser.add_argument(
+        "--scale",
+        type=int,
+        default=7.5,
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="inference",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    seed = args.seed or torch.random.seed()
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
+    text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
+    vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
+    unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
+    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16)
+
+    pipeline = StableDiffusionPipeline(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+        ),
+        safety_checker=NoCheck(),
+        feature_extractor=feature_extractor
+    )
+    pipeline.enable_attention_slicing()
+    pipeline.to("cuda")
+
+    with autocast("cuda"):
+        for i in range(args.batch_num):
+            images = pipeline(
+                [args.prompt] * args.batch_size,
+                num_inference_steps=args.steps,
+                guidance_scale=args.scale,
+                generator=generator,
+            ).images
+
+            for j, image in enumerate(images):
+                image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg"))
+
+
+if __name__ == "__main__":
+    main()
--
cgit v1.2.3-70-g09d2