From 306f2bfb620e6882737658bd3694c79365d75e4b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 18 Oct 2022 15:23:40 +0200 Subject: Improved prompt handling --- data/csv.py | 83 ++++++++++------------ dreambooth_plus.py | 28 +++----- infer.py | 15 ++-- models/clip/prompt.py | 31 ++++++++ .../stable_diffusion/vlpn_stable_diffusion.py | 72 +++++-------------- 5 files changed, 100 insertions(+), 129 deletions(-) create mode 100644 models/clip/prompt.py diff --git a/data/csv.py b/data/csv.py index 316c099..4c91ded 100644 --- a/data/csv.py +++ b/data/csv.py @@ -1,11 +1,14 @@ import math import pandas as pd +import torch from pathlib import Path import pytorch_lightning as pl from PIL import Image from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms -from typing import NamedTuple, List +from typing import NamedTuple, List, Optional + +from models.clip.prompt import PromptProcessor class CSVDataItem(NamedTuple): @@ -18,19 +21,19 @@ class CSVDataItem(NamedTuple): class CSVDataModule(pl.LightningDataModule): def __init__( self, - batch_size, - data_file, - tokenizer, - instance_identifier, - class_identifier=None, - class_subdir="cls", - num_class_images=100, - size=512, - repeats=100, - interpolation="bicubic", - center_crop=False, - valid_set_size=None, - generator=None, + batch_size: int, + data_file: str, + prompt_processor: PromptProcessor, + instance_identifier: str, + class_identifier: Optional[str] = None, + class_subdir: str = "cls", + num_class_images: int = 100, + size: int = 512, + repeats: int = 1, + interpolation: str = "bicubic", + center_crop: bool = False, + valid_set_size: Optional[int] = None, + generator: Optional[torch.Generator] = None, collate_fn=None ): super().__init__() @@ -45,7 +48,7 @@ class CSVDataModule(pl.LightningDataModule): self.class_root.mkdir(parents=True, exist_ok=True) self.num_class_images = num_class_images - self.tokenizer = tokenizer + self.prompt_processor = prompt_processor self.instance_identifier = instance_identifier self.class_identifier = class_identifier self.size = size @@ -65,7 +68,7 @@ class CSVDataModule(pl.LightningDataModule): self.data_root.joinpath(item.image), self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), item.prompt, - item.nprompt if "nprompt" in item else "" + item.nprompt ) for item in data for i in range(image_multiplier) @@ -88,12 +91,12 @@ class CSVDataModule(pl.LightningDataModule): self.data_val = self.prepare_subdata(data_val) def setup(self, stage=None): - train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, + train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, num_class_images=self.num_class_images, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats) - val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, + val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, instance_identifier=self.instance_identifier, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats) @@ -113,19 +116,19 @@ class CSVDataset(Dataset): def __init__( self, data: List[CSVDataItem], - tokenizer, - instance_identifier, - batch_size=1, - class_identifier=None, - num_class_images=0, - size=512, - repeats=1, - interpolation="bicubic", - center_crop=False, + 
prompt_processor: PromptProcessor, + instance_identifier: str, + batch_size: int = 1, + class_identifier: Optional[str] = None, + num_class_images: int = 0, + size: int = 512, + repeats: int = 1, + interpolation: str = "bicubic", + center_crop: bool = False, ): self.data = data - self.tokenizer = tokenizer + self.prompt_processor = prompt_processor self.batch_size = batch_size self.instance_identifier = instance_identifier self.class_identifier = class_identifier @@ -163,12 +166,6 @@ class CSVDataset(Dataset): example = {} - if isinstance(item.prompt, str): - item.prompt = [item.prompt] - - if isinstance(item.nprompt, str): - item.nprompt = [item.nprompt] - example["prompts"] = item.prompt example["nprompts"] = item.nprompt @@ -181,12 +178,9 @@ class CSVDataset(Dataset): self.image_cache[item.instance_image_path] = instance_image example["instance_images"] = instance_image - example["instance_prompt_ids"] = self.tokenizer( - item.prompt.format(self.instance_identifier), - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - ).input_ids + example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( + item.prompt.format(self.instance_identifier) + ) if self.num_class_images != 0: class_image = Image.open(item.class_image_path) @@ -194,12 +188,9 @@ class CSVDataset(Dataset): class_image = class_image.convert("RGB") example["class_images"] = class_image - example["class_prompt_ids"] = self.tokenizer( - item.prompt.format(self.class_identifier), - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - ).input_ids + example["class_prompt_ids"] = self.prompt_processor.get_input_ids( + item.nprompt.format(self.class_identifier) + ) self.cache[item.instance_image_path] = example return example diff --git a/dreambooth_plus.py b/dreambooth_plus.py index ae31377..fa3a22b 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py @@ -26,6 +26,7 @@ from slugify import slugify from schedulers.scheduling_euler_a import EulerAScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule +from models.clip.prompt import PromptProcessor logger = get_logger(__name__) @@ -147,7 +148,7 @@ def parse_args(): parser.add_argument( "--learning_rate_text", type=float, - default=1e-6, + default=5e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -470,7 +471,7 @@ class Checkpointer: for i in range(self.sample_batches): batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] prompt = [ - [p.format(self.instance_identifier) for p in prompt] + prompt.format(self.instance_identifier) for batch in batches for prompt in batch["prompts"] ][:self.sample_batch_size] @@ -573,6 +574,8 @@ def main(): device=accelerator.device ) if args.use_ema else None + prompt_processor = PromptProcessor(tokenizer, text_encoder) + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -663,7 +666,7 @@ def main(): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) - input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + input_ids = prompt_processor.unify_input_ids(input_ids) batch = { "prompts": prompts, @@ -673,21 +676,10 @@ def main(): } return batch - def encode_input_ids(input_ids): - text_embeddings = [] - - for ids in input_ids: - embeddings = text_encoder(ids)[0] - embeddings 
= embeddings.reshape((1, -1, 768)) - text_embeddings.append(embeddings) - - text_embeddings = torch.cat(text_embeddings) - return text_embeddings - datamodule = CSVDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, - tokenizer=tokenizer, + prompt_processor=prompt_processor, instance_identifier=args.instance_identifier, class_identifier=args.class_identifier, class_subdir="cls", @@ -727,7 +719,7 @@ def main(): with torch.inference_mode(): for batch in batched_data: image_name = [item.class_image_path for item in batch] - prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch] + prompt = [item.prompt.format(args.class_identifier) for item in batch] nprompt = [item.nprompt for item in batch] images = pipeline( @@ -875,7 +867,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = encode_input_ids(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -974,7 +966,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = encode_input_ids(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample diff --git a/infer.py b/infer.py index d744768..8e17c4e 100644 --- a/infer.py +++ b/infer.py @@ -19,9 +19,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion torch.backends.cuda.matmul.allow_tf32 = True -line_sep = " " - - default_args = { "model": None, "scheduler": "euler_a", @@ -254,8 +251,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): def generate(output_dir, pipeline, args): + if isinstance(args.prompt, str): + args.prompt = [args.prompt] + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") + output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") output_dir.mkdir(parents=True, exist_ok=True) seed = args.seed or torch.random.seed() @@ -276,14 +276,9 @@ def generate(output_dir, pipeline, args): dynamic_ncols=True ) - if isinstance(args.prompt, str): - args.prompt = [args.prompt] - - prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size - generator = torch.Generator(device="cuda").manual_seed(seed + i) images = pipeline( - prompt=prompt, + prompt=args.prompt * (args.batch_size // len(args.prompt)), height=args.height, width=args.width, negative_prompt=args.negative_prompt, diff --git a/models/clip/prompt.py b/models/clip/prompt.py new file mode 100644 index 0000000..c1e3340 --- /dev/null +++ b/models/clip/prompt.py @@ -0,0 +1,31 @@ +from typing import List, Optional, Union + +import torch + +from transformers import CLIPTokenizer, CLIPTextModel + + +class PromptProcessor(): + def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel): + self.tokenizer = tokenizer + self.text_encoder = text_encoder + + def get_input_ids(self, prompt: Union[str, List[str]]): + return self.tokenizer( + prompt, + padding="do_not_pad", + ).input_ids + + def unify_input_ids(self, input_ids: List[int]): + return self.tokenizer.pad( + {"input_ids": input_ids}, + padding=True, + pad_to_multiple_of=self.tokenizer.model_max_length, + return_tensors="pt" + ).input_ids + + def 
get_embeddings(self, input_ids: torch.IntTensor): + prompts = input_ids.shape[0] + input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device) + text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768)) + return text_embeddings diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index b68b028..3da0169 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -10,8 +10,9 @@ from diffusers.configuration_utils import FrozenDict from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging -from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel +from transformers import CLIPTextModel, CLIPTokenizer from schedulers.scheduling_euler_a import EulerAScheduler +from models.clip.prompt import PromptProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -24,22 +25,6 @@ def preprocess(image, w, h): return 2.0 * image - 1.0 -def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None): - if isinstance(prompt, str): - prompt = [prompt] * batch_size - - if isinstance(prompt, list) and isinstance(prompt[0], str): - prompt = [[p] for p in prompt] - - if isinstance(prompt, list) and isinstance(prompt[0], list): - prompt_size = prompt_size or max([len(p) for p in prompt]) - prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt] - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - return prompt_size, prompt - - class VlpnStableDiffusion(DiffusionPipeline): def __init__( self, @@ -66,6 +51,8 @@ class VlpnStableDiffusion(DiffusionPipeline): new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + self.prompt_processor = PromptProcessor(tokenizer, text_encoder) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -101,34 +88,6 @@ class VlpnStableDiffusion(DiffusionPipeline): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) - def embeddings_for_prompt(self, prompt: List[List[str]]): - text_embeddings = [] - - for p in prompt: - inputs = self.tokenizer( - p, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - input_ids = inputs.input_ids - - if input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - print(f"Too many tokens: {removed_text}") - input_ids = input_ids[:, : self.tokenizer.model_max_length] - - embeddings = self.text_encoder(input_ids.to(self.device))[0] - embeddings = embeddings.reshape((1, -1, 768)) - text_embeddings.append(embeddings) - - text_embeddings = torch.cat(text_embeddings) - return text_embeddings - @torch.no_grad() def __call__( self, @@ -195,13 +154,17 @@ class VlpnStableDiffusion(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. 
""" - prompt_size, prompt = normalize_prompt(prompt) + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) - _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size) - if len(negative_prompt) != batch_size: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + + if len(negative_prompt) != len(prompt): raise ValueError( - f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}") + f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -213,7 +176,7 @@ class VlpnStableDiffusion(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps) # get prompt text embeddings - text_embeddings = self.embeddings_for_prompt(prompt) + text_input_ids = self.prompt_processor.get_input_ids(prompt) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -221,12 +184,11 @@ class VlpnStableDiffusion(DiffusionPipeline): do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: - uncond_embeddings = self.embeddings_for_prompt(negative_prompt) + unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt) + text_input_ids = unconditional_input_ids + text_input_ids - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids) + text_embeddings = self.prompt_processor.get_embeddings(text_input_ids) offset = self.scheduler.config.get("steps_offset", 0) init_timestep = num_inference_steps + offset -- cgit v1.2.3-54-g00ecf