From 728dfcf57c30f40236b3a00d7380c4e0057cacb3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 17 Oct 2022 22:08:58 +0200 Subject: Implemented extended prompt limit --- data/csv.py | 14 +++- dreambooth_plus.py | 44 +++++++--- infer.py | 12 ++- .../stable_diffusion/vlpn_stable_diffusion.py | 96 +++++++++++++--------- 4 files changed, 108 insertions(+), 58 deletions(-) diff --git a/data/csv.py b/data/csv.py index aad970c..316c099 100644 --- a/data/csv.py +++ b/data/csv.py @@ -72,8 +72,8 @@ class CSVDataModule(pl.LightningDataModule): ] def prepare_data(self): - metadata = pd.read_csv(self.data_file) - metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != "x"] + metadata = pd.read_json(self.data_file) + metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True] num_images = len(metadata) valid_set_size = int(num_images * 0.2) @@ -163,6 +163,12 @@ class CSVDataset(Dataset): example = {} + if isinstance(item.prompt, str): + item.prompt = [item.prompt] + + if isinstance(item.nprompt, str): + item.nprompt = [item.nprompt] + example["prompts"] = item.prompt example["nprompts"] = item.nprompt @@ -177,7 +183,7 @@ class CSVDataset(Dataset): example["instance_images"] = instance_image example["instance_prompt_ids"] = self.tokenizer( item.prompt.format(self.instance_identifier), - padding="do_not_pad", + padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids @@ -190,7 +196,7 @@ class CSVDataset(Dataset): example["class_images"] = class_image example["class_prompt_ids"] = self.tokenizer( item.prompt.format(self.class_identifier), - padding="do_not_pad", + padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids diff --git a/dreambooth_plus.py b/dreambooth_plus.py index a98417f..ae31377 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py @@ -124,7 +124,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=1500, + default=1400, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -147,7 +147,7 @@ def parse_args(): parser.add_argument( "--learning_rate_text", type=float, - default=5e-6, + default=1e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -469,9 +469,16 @@ class Checkpointer: for i in range(self.sample_batches): batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] - prompt = [prompt.format(self.instance_identifier) - for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] - nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] + prompt = [ + [p.format(self.instance_identifier) for p in prompt] + for batch in batches + for prompt in batch["prompts"] + ][:self.sample_batch_size] + nprompt = [ + prompt + for batch in batches + for prompt in batch["nprompts"] + ][:self.sample_batch_size] samples = pipeline( prompt=prompt, @@ -666,6 +673,17 @@ def main(): } return batch + def encode_input_ids(input_ids): + text_embeddings = [] + + for ids in input_ids: + embeddings = text_encoder(ids)[0] + embeddings = embeddings.reshape((1, -1, 768)) + text_embeddings.append(embeddings) + + text_embeddings = torch.cat(text_embeddings) + return text_embeddings + datamodule = CSVDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, @@ -688,8 +706,10 @@ def main(): missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] if len(missing_data) != 0: - batched_data = [missing_data[i:i+args.sample_batch_size] - for i in range(0, len(missing_data), args.sample_batch_size)] + batched_data = [ + missing_data[i:i+args.sample_batch_size] + for i in range(0, len(missing_data), args.sample_batch_size) + ] scheduler = EulerAScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" @@ -706,9 +726,9 @@ def main(): with torch.inference_mode(): for batch in batched_data: - image_name = [p.class_image_path for p in batch] - prompt = [p.prompt.format(args.class_identifier) for p in batch] - nprompt = [p.nprompt for p in batch] + image_name = [item.class_image_path for item in batch] + prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch] + nprompt = [item.nprompt for item in batch] images = pipeline( prompt=prompt, @@ -855,7 +875,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = encode_input_ids(batch["input_ids"]) # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -954,7 +974,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = encode_input_ids(batch["input_ids"]) noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample diff --git a/infer.py b/infer.py index 1a0baf5..d744768 100644 --- a/infer.py +++ b/infer.py @@ -19,6 +19,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion torch.backends.cuda.matmul.allow_tf32 = True +line_sep = " " + + default_args = { "model": None, "scheduler": "euler_a", @@ -95,10 +98,12 @@ def create_cmd_parser(): parser.add_argument( "--prompt", type=str, + nargs="+", ) parser.add_argument( "--negative_prompt", type=str, + nargs="*", ) parser.add_argument( "--image", @@ -271,9 +276,14 @@ def generate(output_dir, pipeline, args): dynamic_ncols=True ) + if isinstance(args.prompt, str): + args.prompt = [args.prompt] + + prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size + generator = torch.Generator(device="cuda").manual_seed(seed + i) images = pipeline( - prompt=[args.prompt] * args.batch_size, + prompt=prompt, height=args.height, width=args.width, negative_prompt=args.negative_prompt, diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 8b08a6f..b68b028 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -24,6 +24,22 @@ def preprocess(image, w, h): return 2.0 * image - 1.0 +def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None): + if isinstance(prompt, str): + prompt = [prompt] * batch_size + + if isinstance(prompt, list) and isinstance(prompt[0], str): + prompt = [[p] for p in prompt] + + if isinstance(prompt, list) and isinstance(prompt[0], list): + prompt_size = prompt_size or max([len(p) for p in prompt]) + prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt] + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + return prompt_size, prompt + + class VlpnStableDiffusion(DiffusionPipeline): def __init__( self, @@ -85,11 +101,39 @@ class VlpnStableDiffusion(DiffusionPipeline): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) + def embeddings_for_prompt(self, prompt: List[List[str]]): + text_embeddings = [] + + for p in prompt: + inputs = self.tokenizer( + p, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + input_ids = inputs.input_ids + + if input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + print(f"Too many tokens: {removed_text}") + input_ids = input_ids[:, : self.tokenizer.model_max_length] + + embeddings = self.text_encoder(input_ids.to(self.device))[0] + embeddings = embeddings.reshape((1, -1, 768)) + text_embeddings.append(embeddings) + + text_embeddings = torch.cat(text_embeddings) + return text_embeddings + @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, + prompt: Union[str, List[str], List[List[str]]], + negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, strength: float = 0.8, height: Optional[int] = 512, width: Optional[int] = 512, @@ -151,23 +195,13 @@ class VlpnStableDiffusion(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - elif isinstance(negative_prompt, list): - if len(negative_prompt) != batch_size: - raise ValueError( - f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") - else: - raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + prompt_size, prompt = normalize_prompt(prompt) + batch_size = len(prompt) + _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size) + + if len(negative_prompt) != batch_size: + raise ValueError( + f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -179,23 +213,7 @@ class VlpnStableDiffusion(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps) # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - print(f"Too many tokens: {removed_text}") - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + text_embeddings = self.embeddings_for_prompt(prompt) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -203,11 +221,7 @@ class VlpnStableDiffusion(DiffusionPipeline): do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.embeddings_for_prompt(negative_prompt) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch -- cgit v1.2.3-54-g00ecf