-rw-r--r--   data/csv.py                                          |  83
-rw-r--r--   dreambooth_plus.py                                   |  28
-rw-r--r--   infer.py                                             |  15
-rw-r--r--   models/clip/prompt.py                                |  31
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  72
5 files changed, 100 insertions, 129 deletions
diff --git a/data/csv.py b/data/csv.py
index 316c099..4c91ded 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,11 +1,14 @@
 import math
 import pandas as pd
+import torch
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import NamedTuple, List
+from typing import NamedTuple, List, Optional
+
+from models.clip.prompt import PromptProcessor
 
 
 class CSVDataItem(NamedTuple):
@@ -18,19 +21,19 @@ class CSVDataItem(NamedTuple):
 class CSVDataModule(pl.LightningDataModule):
     def __init__(
         self,
-        batch_size,
-        data_file,
-        tokenizer,
-        instance_identifier,
-        class_identifier=None,
-        class_subdir="cls",
-        num_class_images=100,
-        size=512,
-        repeats=100,
-        interpolation="bicubic",
-        center_crop=False,
-        valid_set_size=None,
-        generator=None,
+        batch_size: int,
+        data_file: str,
+        prompt_processor: PromptProcessor,
+        instance_identifier: str,
+        class_identifier: Optional[str] = None,
+        class_subdir: str = "cls",
+        num_class_images: int = 100,
+        size: int = 512,
+        repeats: int = 1,
+        interpolation: str = "bicubic",
+        center_crop: bool = False,
+        valid_set_size: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
         collate_fn=None
     ):
         super().__init__()
@@ -45,7 +48,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.num_class_images = num_class_images
 
-        self.tokenizer = tokenizer
+        self.prompt_processor = prompt_processor
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.size = size
@@ -65,7 +68,7 @@ class CSVDataModule(pl.LightningDataModule):
                 self.data_root.joinpath(item.image),
                 self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
                 item.prompt,
-                item.nprompt if "nprompt" in item else ""
+                item.nprompt
             )
             for item in data
             for i in range(image_multiplier)
@@ -88,12 +91,12 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_val = self.prepare_subdata(data_val)
 
     def setup(self, stage=None):
-        train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size,
+        train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                    num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
                                    center_crop=self.center_crop, repeats=self.repeats)
-        val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size,
+        val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
                                  instance_identifier=self.instance_identifier,
                                  size=self.size, interpolation=self.interpolation,
                                  center_crop=self.center_crop, repeats=self.repeats)
@@ -113,19 +116,19 @@ class CSVDataset(Dataset):
     def __init__(
         self,
         data: List[CSVDataItem],
-        tokenizer,
-        instance_identifier,
-        batch_size=1,
-        class_identifier=None,
-        num_class_images=0,
-        size=512,
-        repeats=1,
-        interpolation="bicubic",
-        center_crop=False,
+        prompt_processor: PromptProcessor,
+        instance_identifier: str,
+        batch_size: int = 1,
+        class_identifier: Optional[str] = None,
+        num_class_images: int = 0,
+        size: int = 512,
+        repeats: int = 1,
+        interpolation: str = "bicubic",
+        center_crop: bool = False,
     ):
 
         self.data = data
-        self.tokenizer = tokenizer
+        self.prompt_processor = prompt_processor
         self.batch_size = batch_size
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
@@ -163,12 +166,6 @@ class CSVDataset(Dataset):
 
         example = {}
 
-        if isinstance(item.prompt, str):
-            item.prompt = [item.prompt]
-
-        if isinstance(item.nprompt, str):
-            item.nprompt = [item.nprompt]
-
         example["prompts"] = item.prompt
         example["nprompts"] = item.nprompt
 
@@ -181,12 +178,9 @@ class CSVDataset(Dataset):
             self.image_cache[item.instance_image_path] = instance_image
 
         example["instance_images"] = instance_image
-        example["instance_prompt_ids"] = self.tokenizer(
-            item.prompt.format(self.instance_identifier),
-            padding="max_length",
-            truncation=True,
-            max_length=self.tokenizer.model_max_length,
-        ).input_ids
+        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
+            item.prompt.format(self.instance_identifier)
+        )
 
         if self.num_class_images != 0:
             class_image = Image.open(item.class_image_path)
@@ -194,12 +188,9 @@ class CSVDataset(Dataset):
                 class_image = class_image.convert("RGB")
 
             example["class_images"] = class_image
-            example["class_prompt_ids"] = self.tokenizer(
-                item.prompt.format(self.class_identifier),
-                padding="max_length",
-                truncation=True,
-                max_length=self.tokenizer.model_max_length,
-            ).input_ids
+            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
+                item.nprompt.format(self.class_identifier)
+            )
 
         self.cache[item.instance_image_path] = example
         return example
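Note: data/csv.py now stores unpadded token ids (`PromptProcessor.get_input_ids` tokenizes with `padding="do_not_pad"` and no truncation); padding is deferred to the trainer's collate step via `unify_input_ids`. A rough sketch of that behavior, assuming a standard CLIP tokenizer (the model name below is only an example):

```python
# Sketch only, not part of the patch.
from transformers import CLIPTokenizer

from models.clip.prompt import PromptProcessor

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
processor = PromptProcessor(tokenizer, text_encoder=None)  # encoder not needed for tokenization

ids_a = processor.get_input_ids("a photo of a dog")
ids_b = processor.get_input_ids("a watercolor painting of a dog on a mountain at sunset")
print(len(ids_a), len(ids_b))   # variable lengths, no padding yet

batch = processor.unify_input_ids([ids_a, ids_b])
print(batch.shape)              # (2, 77): padded up to a multiple of model_max_length
```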
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index ae31377..fa3a22b 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -26,6 +26,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_a import EulerAScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
 
@@ -147,7 +148,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=1e-6,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -470,7 +471,7 @@ class Checkpointer:
             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
                 prompt = [
-                    [p.format(self.instance_identifier) for p in prompt]
+                    prompt.format(self.instance_identifier)
                     for batch in batches
                     for prompt in batch["prompts"]
                 ][:self.sample_batch_size]
@@ -573,6 +574,8 @@ def main():
         device=accelerator.device
     ) if args.use_ema else None
 
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
@@ -663,7 +666,7 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+        input_ids = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
@@ -673,21 +676,10 @@ def main():
         }
         return batch
 
-    def encode_input_ids(input_ids):
-        text_embeddings = []
-
-        for ids in input_ids:
-            embeddings = text_encoder(ids)[0]
-            embeddings = embeddings.reshape((1, -1, 768))
-            text_embeddings.append(embeddings)
-
-        text_embeddings = torch.cat(text_embeddings)
-        return text_embeddings
-
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
+        prompt_processor=prompt_processor,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
@@ -727,7 +719,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
-                prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch]
+                prompt = [item.prompt.format(args.class_identifier) for item in batch]
                 nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
@@ -875,7 +867,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = encode_input_ids(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -974,7 +966,7 @@ def main():
 
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = encode_input_ids(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
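Note: the training-side flow is now: the dataset emits unpadded ids, `collate_fn` merges them with `unify_input_ids`, and the step encodes them with `get_embeddings`. A hedged sketch of that flow (not a verbatim excerpt of dreambooth_plus.py; `prompt_processor` is the object created in `main()` above):

```python
# Sketch of the training-side flow after this change.
def collate_fn(examples):
    input_ids = [example["instance_prompt_ids"] for example in examples]  # unpadded lists
    return {
        # one (batch, 77 * n_chunks) tensor, padded to whole 77-token windows
        "input_ids": prompt_processor.unify_input_ids(input_ids),
    }

# later, in the training step:
# encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
# noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
```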
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -19,9 +19,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 torch.backends.cuda.matmul.allow_tf32 = True
 
 
-line_sep = " <OR> "
-
-
 default_args = {
     "model": None,
     "scheduler": "euler_a",
@@ -254,8 +251,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
 
 
 def generate(output_dir, pipeline, args):
+    if isinstance(args.prompt, str):
+        args.prompt = [args.prompt]
+
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
+    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
     output_dir.mkdir(parents=True, exist_ok=True)
 
     seed = args.seed or torch.random.seed()
@@ -276,14 +276,9 @@ def generate(output_dir, pipeline, args):
             dynamic_ncols=True
         )
 
-        if isinstance(args.prompt, str):
-            args.prompt = [args.prompt]
-
-        prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size
-
         generator = torch.Generator(device="cuda").manual_seed(seed + i)
         images = pipeline(
-            prompt=prompt,
+            prompt=args.prompt * (args.batch_size // len(args.prompt)),
            height=args.height,
            width=args.width,
            negative_prompt=args.negative_prompt,
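Note: with the `<OR>` splitting removed, `generate()` fills a batch by tiling the prompt list, which assumes `args.batch_size` is a multiple of `len(args.prompt)`. A small illustration of the tiling:

```python
# Sketch of the batch construction used above.
prompts = ["a castle on a hill", "a forest in the fog"]
batch_size = 4
print(prompts * (batch_size // len(prompts)))
# ['a castle on a hill', 'a forest in the fog', 'a castle on a hill', 'a forest in the fog']
```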
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
new file mode 100644
index 0000000..c1e3340
--- /dev/null
+++ b/models/clip/prompt.py
@@ -0,0 +1,31 @@
+from typing import List, Optional, Union
+
+import torch
+
+from transformers import CLIPTokenizer, CLIPTextModel
+
+
+class PromptProcessor():
+    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+
+    def get_input_ids(self, prompt: Union[str, List[str]]):
+        return self.tokenizer(
+            prompt,
+            padding="do_not_pad",
+        ).input_ids
+
+    def unify_input_ids(self, input_ids: List[int]):
+        return self.tokenizer.pad(
+            {"input_ids": input_ids},
+            padding=True,
+            pad_to_multiple_of=self.tokenizer.model_max_length,
+            return_tensors="pt"
+        ).input_ids
+
+    def get_embeddings(self, input_ids: torch.IntTensor):
+        prompts = input_ids.shape[0]
+        input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768))
+        return text_embeddings
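Note: `PromptProcessor` replaces the old hard truncation at 77 tokens. `get_input_ids` tokenizes without padding or truncation, `unify_input_ids` pads every prompt in the batch to a common multiple of `model_max_length` (77 for CLIP), and `get_embeddings` encodes each 77-token window separately before stitching the windows back together per prompt. A hedged usage sketch (model name and prompts are placeholders):

```python
# Sketch of how the three methods compose for a prompt longer than 77 tokens.
from transformers import CLIPTextModel, CLIPTokenizer

from models.clip.prompt import PromptProcessor

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
processor = PromptProcessor(tokenizer, text_encoder)

ids = processor.get_input_ids([
    "a short prompt",
    "a very long prompt " * 30,   # well past 77 tokens, no longer truncated
])
batch = processor.unify_input_ids(ids)        # padded to a multiple of 77
embeddings = processor.get_embeddings(batch)  # each 77-token window encoded on its own

print(batch.shape)       # e.g. (2, 154) if the long prompt needs two windows
print(embeddings.shape)  # (2, 154, 768): windows concatenated back per prompt
```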
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b68b028..3da0169 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -10,8 +10,9 @@ from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel
+from transformers import CLIPTextModel, CLIPTokenizer
 from schedulers.scheduling_euler_a import EulerAScheduler
+from models.clip.prompt import PromptProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -24,22 +25,6 @@ def preprocess(image, w, h):
     return 2.0 * image - 1.0
 
 
-def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None):
-    if isinstance(prompt, str):
-        prompt = [prompt] * batch_size
-
-    if isinstance(prompt, list) and isinstance(prompt[0], str):
-        prompt = [[p] for p in prompt]
-
-    if isinstance(prompt, list) and isinstance(prompt[0], list):
-        prompt_size = prompt_size or max([len(p) for p in prompt])
-        prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt]
-    else:
-        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-    return prompt_size, prompt
-
-
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
@@ -66,6 +51,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -101,34 +88,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
-    def embeddings_for_prompt(self, prompt: List[List[str]]):
-        text_embeddings = []
-
-        for p in prompt:
-            inputs = self.tokenizer(
-                p,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                return_tensors="pt",
-            )
-            input_ids = inputs.input_ids
-
-            if input_ids.shape[-1] > self.tokenizer.model_max_length:
-                removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:])
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-                print(f"Too many tokens: {removed_text}")
-                input_ids = input_ids[:, : self.tokenizer.model_max_length]
-
-            embeddings = self.text_encoder(input_ids.to(self.device))[0]
-            embeddings = embeddings.reshape((1, -1, 768))
-            text_embeddings.append(embeddings)
-
-        text_embeddings = torch.cat(text_embeddings)
-        return text_embeddings
-
     @torch.no_grad()
     def __call__(
         self,
@@ -195,13 +154,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
             (nsfw) content, according to the `safety_checker`.
         """
 
-        prompt_size, prompt = normalize_prompt(prompt)
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
         batch_size = len(prompt)
-        _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size)
 
-        if len(negative_prompt) != batch_size:
+        if isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+
+        if len(negative_prompt) != len(prompt):
             raise ValueError(
-                f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}")
+                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
 
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -213,7 +176,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
 
         # get prompt text embeddings
-        text_embeddings = self.embeddings_for_prompt(prompt)
+        text_input_ids = self.prompt_processor.get_input_ids(prompt)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -221,12 +184,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            uncond_embeddings = self.embeddings_for_prompt(negative_prompt)
+            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            text_input_ids = unconditional_input_ids + text_input_ids
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
 
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = num_inference_steps + offset
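Note: under classifier-free guidance the unconditional ids are now concatenated ahead of the prompt ids and everything is padded and encoded in a single pass, so the unconditional half still comes first in `text_embeddings`, exactly as the old `torch.cat([uncond_embeddings, text_embeddings])` produced. Downstream guidance code that splits predictions in half is therefore unaffected; a hedged sketch of that standard combination step (not shown in this diff):

```python
import torch

# Sketch only: the usual chunk-and-combine guidance arithmetic that consumes the
# stacked embeddings; shapes are illustrative.
guidance_scale = 7.5
noise_pred = torch.randn(2 * 4, 4, 64, 64)                # unconditional half first, then conditional
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```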
