author     Volpeon <git@volpeon.ink>   2022-10-18 15:23:40 +0200
committer  Volpeon <git@volpeon.ink>   2022-10-18 15:23:40 +0200
commit     306f2bfb620e6882737658bd3694c79365d75e4b (patch)
tree       8b461c4360b9baa5758c2af0100348f14df8c76d
parent     Implemented extended prompt limit (diff)

Improved prompt handling

-rw-r--r--  data/csv.py                                          | 83
-rw-r--r--  dreambooth_plus.py                                   | 28
-rw-r--r--  infer.py                                             | 15
-rw-r--r--  models/clip/prompt.py                                | 31
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py | 72

5 files changed, 100 insertions, 129 deletions

diff --git a/data/csv.py b/data/csv.py
index 316c099..4c91ded 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,11 +1,14 @@
 import math
 import pandas as pd
+import torch
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import NamedTuple, List
+from typing import NamedTuple, List, Optional
+
+from models.clip.prompt import PromptProcessor
 
 
 class CSVDataItem(NamedTuple):
@@ -18,19 +21,19 @@ class CSVDataItem(NamedTuple):
 class CSVDataModule(pl.LightningDataModule):
     def __init__(
         self,
-        batch_size,
-        data_file,
-        tokenizer,
-        instance_identifier,
-        class_identifier=None,
-        class_subdir="cls",
-        num_class_images=100,
-        size=512,
-        repeats=100,
-        interpolation="bicubic",
-        center_crop=False,
-        valid_set_size=None,
-        generator=None,
+        batch_size: int,
+        data_file: str,
+        prompt_processor: PromptProcessor,
+        instance_identifier: str,
+        class_identifier: Optional[str] = None,
+        class_subdir: str = "cls",
+        num_class_images: int = 100,
+        size: int = 512,
+        repeats: int = 1,
+        interpolation: str = "bicubic",
+        center_crop: bool = False,
+        valid_set_size: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
         collate_fn=None
     ):
         super().__init__()
@@ -45,7 +48,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.num_class_images = num_class_images
 
-        self.tokenizer = tokenizer
+        self.prompt_processor = prompt_processor
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.size = size
@@ -65,7 +68,7 @@ class CSVDataModule(pl.LightningDataModule):
                 self.data_root.joinpath(item.image),
                 self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
                 item.prompt,
-                item.nprompt if "nprompt" in item else ""
+                item.nprompt
             )
             for item in data
             for i in range(image_multiplier)
@@ -88,12 +91,12 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_val = self.prepare_subdata(data_val)
 
     def setup(self, stage=None):
-        train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size,
+        train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                    num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
                                    center_crop=self.center_crop, repeats=self.repeats)
-        val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size,
+        val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
                                  instance_identifier=self.instance_identifier,
                                  size=self.size, interpolation=self.interpolation,
                                  center_crop=self.center_crop, repeats=self.repeats)
@@ -113,19 +116,19 @@ class CSVDataset(Dataset):
     def __init__(
         self,
         data: List[CSVDataItem],
-        tokenizer,
-        instance_identifier,
-        batch_size=1,
-        class_identifier=None,
-        num_class_images=0,
-        size=512,
-        repeats=1,
-        interpolation="bicubic",
-        center_crop=False,
+        prompt_processor: PromptProcessor,
+        instance_identifier: str,
+        batch_size: int = 1,
+        class_identifier: Optional[str] = None,
+        num_class_images: int = 0,
+        size: int = 512,
+        repeats: int = 1,
+        interpolation: str = "bicubic",
+        center_crop: bool = False,
     ):
 
         self.data = data
-        self.tokenizer = tokenizer
+        self.prompt_processor = prompt_processor
         self.batch_size = batch_size
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
@@ -163,12 +166,6 @@ class CSVDataset(Dataset):
 
         example = {}
 
-        if isinstance(item.prompt, str):
-            item.prompt = [item.prompt]
-
-        if isinstance(item.nprompt, str):
-            item.nprompt = [item.nprompt]
-
         example["prompts"] = item.prompt
         example["nprompts"] = item.nprompt
 
@@ -181,12 +178,9 @@ class CSVDataset(Dataset):
             self.image_cache[item.instance_image_path] = instance_image
 
         example["instance_images"] = instance_image
-        example["instance_prompt_ids"] = self.tokenizer(
-            item.prompt.format(self.instance_identifier),
-            padding="max_length",
-            truncation=True,
-            max_length=self.tokenizer.model_max_length,
-        ).input_ids
+        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
+            item.prompt.format(self.instance_identifier)
+        )
 
         if self.num_class_images != 0:
             class_image = Image.open(item.class_image_path)
@@ -194,12 +188,9 @@ class CSVDataset(Dataset):
                 class_image = class_image.convert("RGB")
 
             example["class_images"] = class_image
-            example["class_prompt_ids"] = self.tokenizer(
-                item.prompt.format(self.class_identifier),
-                padding="max_length",
-                truncation=True,
-                max_length=self.tokenizer.model_max_length,
-            ).input_ids
+            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
+                item.nprompt.format(self.class_identifier)
+            )
 
         self.cache[item.instance_image_path] = example
         return example
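With this change the data module no longer takes the raw CLIP tokenizer; it takes the new PromptProcessor and tokenizes without padding, leaving batching to the collate step. A minimal wiring sketch, not part of the commit (the checkpoint name, data file path, and identifier are assumptions):

    from transformers import CLIPTokenizer, CLIPTextModel
    from models.clip.prompt import PromptProcessor
    from data.csv import CSVDataModule

    # Hypothetical checkpoint; any CLIP tokenizer/text-encoder pair is wired the same way.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    datamodule = CSVDataModule(
        batch_size=1,
        data_file="data/train.csv",                                   # hypothetical path
        prompt_processor=PromptProcessor(tokenizer, text_encoder),    # replaces the old tokenizer= argument
        instance_identifier="<token>",                                # hypothetical placeholder
    )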
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index ae31377..fa3a22b 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -26,6 +26,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_a import EulerAScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
 
@@ -147,7 +148,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=1e-6,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -470,7 +471,7 @@ class Checkpointer:
                for i in range(self.sample_batches):
                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
                    prompt = [
-                        [p.format(self.instance_identifier) for p in prompt]
+                        prompt.format(self.instance_identifier)
                        for batch in batches
                        for prompt in batch["prompts"]
                    ][:self.sample_batch_size]
@@ -573,6 +574,8 @@ def main():
        device=accelerator.device
    ) if args.use_ema else None
 
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
 
@@ -663,7 +666,7 @@ def main():
        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+        input_ids = prompt_processor.unify_input_ids(input_ids)
 
        batch = {
            "prompts": prompts,
@@ -673,21 +676,10 @@ def main():
        }
        return batch
 
-    def encode_input_ids(input_ids):
-        text_embeddings = []
-
-        for ids in input_ids:
-            embeddings = text_encoder(ids)[0]
-            embeddings = embeddings.reshape((1, -1, 768))
-            text_embeddings.append(embeddings)
-
-        text_embeddings = torch.cat(text_embeddings)
-        return text_embeddings
-
    datamodule = CSVDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
+        prompt_processor=prompt_processor,
        instance_identifier=args.instance_identifier,
        class_identifier=args.class_identifier,
        class_subdir="cls",
@@ -727,7 +719,7 @@ def main():
        with torch.inference_mode():
            for batch in batched_data:
                image_name = [item.class_image_path for item in batch]
-                prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch]
+                prompt = [item.prompt.format(args.class_identifier) for item in batch]
                nprompt = [item.nprompt for item in batch]
 
                images = pipeline(
@@ -875,7 +867,7 @@ def main():
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                # Get the text embedding for conditioning
-                encoder_hidden_states = encode_input_ids(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -974,7 +966,7 @@ def main():
 
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = encode_input_ids(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
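The inline encode_input_ids helper is gone: the collate function now pads the variable-length ids from the dataset with unify_input_ids, and both the training and validation loops turn the padded batch into conditioning with get_embeddings. In isolation the flow looks roughly like this (a sketch; the prompts and the prompt_processor instance are assumed from the setup above):

    ids = [prompt_processor.get_input_ids(p) for p in ["a photo of <token>", "a photo of a dog"]]
    input_ids = prompt_processor.unify_input_ids(ids)            # padded to a multiple of 77 tokens
    encoder_hidden_states = prompt_processor.get_embeddings(input_ids)
    # encoder_hidden_states is what the UNet receives as conditioning:
    # noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample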
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -19,9 +19,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 torch.backends.cuda.matmul.allow_tf32 = True
 
 
-line_sep = " <OR> "
-
-
 default_args = {
     "model": None,
     "scheduler": "euler_a",
@@ -254,8 +251,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
 
 
 def generate(output_dir, pipeline, args):
+    if isinstance(args.prompt, str):
+        args.prompt = [args.prompt]
+
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
+    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
    output_dir.mkdir(parents=True, exist_ok=True)
 
    seed = args.seed or torch.random.seed()
@@ -276,14 +276,9 @@ def generate(output_dir, pipeline, args):
            dynamic_ncols=True
        )
 
-        if isinstance(args.prompt, str):
-            args.prompt = [args.prompt]
-
-        prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size
-
        generator = torch.Generator(device="cuda").manual_seed(seed + i)
        images = pipeline(
-            prompt=prompt,
+            prompt=args.prompt * (args.batch_size // len(args.prompt)),
            height=args.height,
            width=args.width,
            negative_prompt=args.negative_prompt,
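infer.py drops the " <OR> " sub-prompt syntax: prompts are normalized to a list once at the top of generate(), and the batch is built by repeating that list. The integer division means the effective batch size is rounded down to a multiple of the number of prompts, for example (a sketch with made-up values):

    args_prompt = ["a castle at dusk", "a foggy forest"]
    batch_size = 4
    prompt = args_prompt * (batch_size // len(args_prompt))
    # -> ["a castle at dusk", "a foggy forest", "a castle at dusk", "a foggy forest"]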
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
new file mode 100644
index 0000000..c1e3340
--- /dev/null
+++ b/models/clip/prompt.py
@@ -0,0 +1,31 @@
+from typing import List, Optional, Union
+
+import torch
+
+from transformers import CLIPTokenizer, CLIPTextModel
+
+
+class PromptProcessor():
+    def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+
+    def get_input_ids(self, prompt: Union[str, List[str]]):
+        return self.tokenizer(
+            prompt,
+            padding="do_not_pad",
+        ).input_ids
+
+    def unify_input_ids(self, input_ids: List[int]):
+        return self.tokenizer.pad(
+            {"input_ids": input_ids},
+            padding=True,
+            pad_to_multiple_of=self.tokenizer.model_max_length,
+            return_tensors="pt"
+        ).input_ids
+
+    def get_embeddings(self, input_ids: torch.IntTensor):
+        prompts = input_ids.shape[0]
+        input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768))
+        return text_embeddings
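This new module is the piece that allows prompts longer than the CLIP window: get_input_ids tokenizes without padding or truncation, unify_input_ids pads a batch to a multiple of model_max_length (77 for the CLIP tokenizer), and get_embeddings encodes each 77-token window independently before stitching the windows back together per prompt. A shape sketch under those assumptions (768 is the ViT-L/14 hidden size; the values are illustrative):

    import torch

    # 4 prompts padded to 154 tokens = 2 windows of 77 each.
    padded = torch.zeros((4, 154), dtype=torch.long)
    windows = padded.reshape((-1, 77))                  # (8, 77): one row per window
    # text_encoder(windows)[0] would have shape (8, 77, 768);
    # reshaping back to (4, -1, 768) restores (4, 154, 768), i.e. per-prompt sequences again.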
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b68b028..3da0169 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -10,8 +10,9 @@ from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel
+from transformers import CLIPTextModel, CLIPTokenizer
 from schedulers.scheduling_euler_a import EulerAScheduler
+from models.clip.prompt import PromptProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -24,22 +25,6 @@ def preprocess(image, w, h):
     return 2.0 * image - 1.0
 
 
-def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None):
-    if isinstance(prompt, str):
-        prompt = [prompt] * batch_size
-
-    if isinstance(prompt, list) and isinstance(prompt[0], str):
-        prompt = [[p] for p in prompt]
-
-    if isinstance(prompt, list) and isinstance(prompt[0], list):
-        prompt_size = prompt_size or max([len(p) for p in prompt])
-        prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt]
-    else:
-        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-    return prompt_size, prompt
-
-
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
@@ -66,6 +51,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
@@ -101,34 +88,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)
 
-    def embeddings_for_prompt(self, prompt: List[List[str]]):
-        text_embeddings = []
-
-        for p in prompt:
-            inputs = self.tokenizer(
-                p,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                return_tensors="pt",
-            )
-            input_ids = inputs.input_ids
-
-            if input_ids.shape[-1] > self.tokenizer.model_max_length:
-                removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:])
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-                )
-                print(f"Too many tokens: {removed_text}")
-                input_ids = input_ids[:, : self.tokenizer.model_max_length]
-
-            embeddings = self.text_encoder(input_ids.to(self.device))[0]
-            embeddings = embeddings.reshape((1, -1, 768))
-            text_embeddings.append(embeddings)
-
-        text_embeddings = torch.cat(text_embeddings)
-        return text_embeddings
-
    @torch.no_grad()
    def __call__(
        self,
@@ -195,13 +154,17 @@
                (nsfw) content, according to the `safety_checker`.
        """
 
-        prompt_size, prompt = normalize_prompt(prompt)
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
        batch_size = len(prompt)
-        _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size)
 
-        if len(negative_prompt) != batch_size:
+        if isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+
+        if len(negative_prompt) != len(prompt):
            raise ValueError(
-                f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}")
+                f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
 
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -213,7 +176,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
        self.scheduler.set_timesteps(num_inference_steps)
 
        # get prompt text embeddings
-        text_embeddings = self.embeddings_for_prompt(prompt)
+        text_input_ids = self.prompt_processor.get_input_ids(prompt)
 
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -221,12 +184,11 @@
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
-            uncond_embeddings = self.embeddings_for_prompt(negative_prompt)
+            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            text_input_ids = unconditional_input_ids + text_input_ids
 
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+        text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
+        text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
 
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = num_inference_steps + offset
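For classifier-free guidance the pipeline now concatenates the unconditional and conditional token-id lists before padding, so both halves are padded to the same length and encoded together, with the unconditional embeddings occupying the first half of the batch. Roughly, inside __call__ with a two-prompt batch, this amounts to (a sketch; the prompt values are made up):

    prompt = ["a castle at dusk", "a foggy forest"]
    negative_prompt = [""] * len(prompt)

    text_input_ids = self.prompt_processor.get_input_ids(prompt)                    # 2 id lists
    unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)  # 2 id lists
    text_input_ids = unconditional_input_ids + text_input_ids                       # uncond first, 4 id lists
    padded = self.prompt_processor.unify_input_ids(text_input_ids)                  # (4, n*77)
    text_embeddings = self.prompt_processor.get_embeddings(padded)                  # (4, n*77, 768)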