summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py83
-rw-r--r--dreambooth_plus.py28
-rw-r--r--infer.py15
-rw-r--r--models/clip/prompt.py31
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py72
5 files changed, 100 insertions, 129 deletions
diff --git a/data/csv.py b/data/csv.py
index 316c099..4c91ded 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,11 +1,14 @@
1import math 1import math
2import pandas as pd 2import pandas as pd
3import torch
3from pathlib import Path 4from pathlib import Path
4import pytorch_lightning as pl 5import pytorch_lightning as pl
5from PIL import Image 6from PIL import Image
6from torch.utils.data import Dataset, DataLoader, random_split 7from torch.utils.data import Dataset, DataLoader, random_split
7from torchvision import transforms 8from torchvision import transforms
8from typing import NamedTuple, List 9from typing import NamedTuple, List, Optional
10
11from models.clip.prompt import PromptProcessor
9 12
10 13
11class CSVDataItem(NamedTuple): 14class CSVDataItem(NamedTuple):
@@ -18,19 +21,19 @@ class CSVDataItem(NamedTuple):
18class CSVDataModule(pl.LightningDataModule): 21class CSVDataModule(pl.LightningDataModule):
19 def __init__( 22 def __init__(
20 self, 23 self,
21 batch_size, 24 batch_size: int,
22 data_file, 25 data_file: str,
23 tokenizer, 26 prompt_processor: PromptProcessor,
24 instance_identifier, 27 instance_identifier: str,
25 class_identifier=None, 28 class_identifier: Optional[str] = None,
26 class_subdir="cls", 29 class_subdir: str = "cls",
27 num_class_images=100, 30 num_class_images: int = 100,
28 size=512, 31 size: int = 512,
29 repeats=100, 32 repeats: int = 1,
30 interpolation="bicubic", 33 interpolation: str = "bicubic",
31 center_crop=False, 34 center_crop: bool = False,
32 valid_set_size=None, 35 valid_set_size: Optional[int] = None,
33 generator=None, 36 generator: Optional[torch.Generator] = None,
34 collate_fn=None 37 collate_fn=None
35 ): 38 ):
36 super().__init__() 39 super().__init__()
@@ -45,7 +48,7 @@ class CSVDataModule(pl.LightningDataModule):
45 self.class_root.mkdir(parents=True, exist_ok=True) 48 self.class_root.mkdir(parents=True, exist_ok=True)
46 self.num_class_images = num_class_images 49 self.num_class_images = num_class_images
47 50
48 self.tokenizer = tokenizer 51 self.prompt_processor = prompt_processor
49 self.instance_identifier = instance_identifier 52 self.instance_identifier = instance_identifier
50 self.class_identifier = class_identifier 53 self.class_identifier = class_identifier
51 self.size = size 54 self.size = size
@@ -65,7 +68,7 @@ class CSVDataModule(pl.LightningDataModule):
65 self.data_root.joinpath(item.image), 68 self.data_root.joinpath(item.image),
66 self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), 69 self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
67 item.prompt, 70 item.prompt,
68 item.nprompt if "nprompt" in item else "" 71 item.nprompt
69 ) 72 )
70 for item in data 73 for item in data
71 for i in range(image_multiplier) 74 for i in range(image_multiplier)
@@ -88,12 +91,12 @@ class CSVDataModule(pl.LightningDataModule):
88 self.data_val = self.prepare_subdata(data_val) 91 self.data_val = self.prepare_subdata(data_val)
89 92
90 def setup(self, stage=None): 93 def setup(self, stage=None):
91 train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, 94 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
92 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, 95 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
93 num_class_images=self.num_class_images, 96 num_class_images=self.num_class_images,
94 size=self.size, interpolation=self.interpolation, 97 size=self.size, interpolation=self.interpolation,
95 center_crop=self.center_crop, repeats=self.repeats) 98 center_crop=self.center_crop, repeats=self.repeats)
96 val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, 99 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
97 instance_identifier=self.instance_identifier, 100 instance_identifier=self.instance_identifier,
98 size=self.size, interpolation=self.interpolation, 101 size=self.size, interpolation=self.interpolation,
99 center_crop=self.center_crop, repeats=self.repeats) 102 center_crop=self.center_crop, repeats=self.repeats)
@@ -113,19 +116,19 @@ class CSVDataset(Dataset):
113 def __init__( 116 def __init__(
114 self, 117 self,
115 data: List[CSVDataItem], 118 data: List[CSVDataItem],
116 tokenizer, 119 prompt_processor: PromptProcessor,
117 instance_identifier, 120 instance_identifier: str,
118 batch_size=1, 121 batch_size: int = 1,
119 class_identifier=None, 122 class_identifier: Optional[str] = None,
120 num_class_images=0, 123 num_class_images: int = 0,
121 size=512, 124 size: int = 512,
122 repeats=1, 125 repeats: int = 1,
123 interpolation="bicubic", 126 interpolation: str = "bicubic",
124 center_crop=False, 127 center_crop: bool = False,
125 ): 128 ):
126 129
127 self.data = data 130 self.data = data
128 self.tokenizer = tokenizer 131 self.prompt_processor = prompt_processor
129 self.batch_size = batch_size 132 self.batch_size = batch_size
130 self.instance_identifier = instance_identifier 133 self.instance_identifier = instance_identifier
131 self.class_identifier = class_identifier 134 self.class_identifier = class_identifier
@@ -163,12 +166,6 @@ class CSVDataset(Dataset):
163 166
164 example = {} 167 example = {}
165 168
166 if isinstance(item.prompt, str):
167 item.prompt = [item.prompt]
168
169 if isinstance(item.nprompt, str):
170 item.nprompt = [item.nprompt]
171
172 example["prompts"] = item.prompt 169 example["prompts"] = item.prompt
173 example["nprompts"] = item.nprompt 170 example["nprompts"] = item.nprompt
174 171
@@ -181,12 +178,9 @@ class CSVDataset(Dataset):
181 self.image_cache[item.instance_image_path] = instance_image 178 self.image_cache[item.instance_image_path] = instance_image
182 179
183 example["instance_images"] = instance_image 180 example["instance_images"] = instance_image
184 example["instance_prompt_ids"] = self.tokenizer( 181 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
185 item.prompt.format(self.instance_identifier), 182 item.prompt.format(self.instance_identifier)
186 padding="max_length", 183 )
187 truncation=True,
188 max_length=self.tokenizer.model_max_length,
189 ).input_ids
190 184
191 if self.num_class_images != 0: 185 if self.num_class_images != 0:
192 class_image = Image.open(item.class_image_path) 186 class_image = Image.open(item.class_image_path)
@@ -194,12 +188,9 @@ class CSVDataset(Dataset):
194 class_image = class_image.convert("RGB") 188 class_image = class_image.convert("RGB")
195 189
196 example["class_images"] = class_image 190 example["class_images"] = class_image
197 example["class_prompt_ids"] = self.tokenizer( 191 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
198 item.prompt.format(self.class_identifier), 192 item.nprompt.format(self.class_identifier)
199 padding="max_length", 193 )
200 truncation=True,
201 max_length=self.tokenizer.model_max_length,
202 ).input_ids
203 194
204 self.cache[item.instance_image_path] = example 195 self.cache[item.instance_image_path] = example
205 return example 196 return example
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index ae31377..fa3a22b 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -26,6 +26,7 @@ from slugify import slugify
26from schedulers.scheduling_euler_a import EulerAScheduler 26from schedulers.scheduling_euler_a import EulerAScheduler
27from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 27from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
28from data.csv import CSVDataModule 28from data.csv import CSVDataModule
29from models.clip.prompt import PromptProcessor
29 30
30logger = get_logger(__name__) 31logger = get_logger(__name__)
31 32
@@ -147,7 +148,7 @@ def parse_args():
147 parser.add_argument( 148 parser.add_argument(
148 "--learning_rate_text", 149 "--learning_rate_text",
149 type=float, 150 type=float,
150 default=1e-6, 151 default=5e-6,
151 help="Initial learning rate (after the potential warmup period) to use.", 152 help="Initial learning rate (after the potential warmup period) to use.",
152 ) 153 )
153 parser.add_argument( 154 parser.add_argument(
@@ -470,7 +471,7 @@ class Checkpointer:
470 for i in range(self.sample_batches): 471 for i in range(self.sample_batches):
471 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 472 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
472 prompt = [ 473 prompt = [
473 [p.format(self.instance_identifier) for p in prompt] 474 prompt.format(self.instance_identifier)
474 for batch in batches 475 for batch in batches
475 for prompt in batch["prompts"] 476 for prompt in batch["prompts"]
476 ][:self.sample_batch_size] 477 ][:self.sample_batch_size]
@@ -573,6 +574,8 @@ def main():
573 device=accelerator.device 574 device=accelerator.device
574 ) if args.use_ema else None 575 ) if args.use_ema else None
575 576
577 prompt_processor = PromptProcessor(tokenizer, text_encoder)
578
576 if args.gradient_checkpointing: 579 if args.gradient_checkpointing:
577 unet.enable_gradient_checkpointing() 580 unet.enable_gradient_checkpointing()
578 581
@@ -663,7 +666,7 @@ def main():
663 pixel_values = torch.stack(pixel_values) 666 pixel_values = torch.stack(pixel_values)
664 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) 667 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
665 668
666 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 669 input_ids = prompt_processor.unify_input_ids(input_ids)
667 670
668 batch = { 671 batch = {
669 "prompts": prompts, 672 "prompts": prompts,
@@ -673,21 +676,10 @@ def main():
673 } 676 }
674 return batch 677 return batch
675 678
676 def encode_input_ids(input_ids):
677 text_embeddings = []
678
679 for ids in input_ids:
680 embeddings = text_encoder(ids)[0]
681 embeddings = embeddings.reshape((1, -1, 768))
682 text_embeddings.append(embeddings)
683
684 text_embeddings = torch.cat(text_embeddings)
685 return text_embeddings
686
687 datamodule = CSVDataModule( 679 datamodule = CSVDataModule(
688 data_file=args.train_data_file, 680 data_file=args.train_data_file,
689 batch_size=args.train_batch_size, 681 batch_size=args.train_batch_size,
690 tokenizer=tokenizer, 682 prompt_processor=prompt_processor,
691 instance_identifier=args.instance_identifier, 683 instance_identifier=args.instance_identifier,
692 class_identifier=args.class_identifier, 684 class_identifier=args.class_identifier,
693 class_subdir="cls", 685 class_subdir="cls",
@@ -727,7 +719,7 @@ def main():
727 with torch.inference_mode(): 719 with torch.inference_mode():
728 for batch in batched_data: 720 for batch in batched_data:
729 image_name = [item.class_image_path for item in batch] 721 image_name = [item.class_image_path for item in batch]
730 prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch] 722 prompt = [item.prompt.format(args.class_identifier) for item in batch]
731 nprompt = [item.nprompt for item in batch] 723 nprompt = [item.nprompt for item in batch]
732 724
733 images = pipeline( 725 images = pipeline(
@@ -875,7 +867,7 @@ def main():
875 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 867 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
876 868
877 # Get the text embedding for conditioning 869 # Get the text embedding for conditioning
878 encoder_hidden_states = encode_input_ids(batch["input_ids"]) 870 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
879 871
880 # Predict the noise residual 872 # Predict the noise residual
881 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 873 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -974,7 +966,7 @@ def main():
974 966
975 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 967 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
976 968
977 encoder_hidden_states = encode_input_ids(batch["input_ids"]) 969 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
978 970
979 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 971 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
980 972
diff --git a/infer.py b/infer.py
index d744768..8e17c4e 100644
--- a/infer.py
+++ b/infer.py
@@ -19,9 +19,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
19torch.backends.cuda.matmul.allow_tf32 = True 19torch.backends.cuda.matmul.allow_tf32 = True
20 20
21 21
22line_sep = " <OR> "
23
24
25default_args = { 22default_args = {
26 "model": None, 23 "model": None,
27 "scheduler": "euler_a", 24 "scheduler": "euler_a",
@@ -254,8 +251,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
254 251
255 252
256def generate(output_dir, pipeline, args): 253def generate(output_dir, pipeline, args):
254 if isinstance(args.prompt, str):
255 args.prompt = [args.prompt]
256
257 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 257 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
258 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") 258 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
259 output_dir.mkdir(parents=True, exist_ok=True) 259 output_dir.mkdir(parents=True, exist_ok=True)
260 260
261 seed = args.seed or torch.random.seed() 261 seed = args.seed or torch.random.seed()
@@ -276,14 +276,9 @@ def generate(output_dir, pipeline, args):
276 dynamic_ncols=True 276 dynamic_ncols=True
277 ) 277 )
278 278
279 if isinstance(args.prompt, str):
280 args.prompt = [args.prompt]
281
282 prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size
283
284 generator = torch.Generator(device="cuda").manual_seed(seed + i) 279 generator = torch.Generator(device="cuda").manual_seed(seed + i)
285 images = pipeline( 280 images = pipeline(
286 prompt=prompt, 281 prompt=args.prompt * (args.batch_size // len(args.prompt)),
287 height=args.height, 282 height=args.height,
288 width=args.width, 283 width=args.width,
289 negative_prompt=args.negative_prompt, 284 negative_prompt=args.negative_prompt,
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
new file mode 100644
index 0000000..c1e3340
--- /dev/null
+++ b/models/clip/prompt.py
@@ -0,0 +1,31 @@
1from typing import List, Optional, Union
2
3import torch
4
5from transformers import CLIPTokenizer, CLIPTextModel
6
7
8class PromptProcessor():
9 def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
10 self.tokenizer = tokenizer
11 self.text_encoder = text_encoder
12
13 def get_input_ids(self, prompt: Union[str, List[str]]):
14 return self.tokenizer(
15 prompt,
16 padding="do_not_pad",
17 ).input_ids
18
19 def unify_input_ids(self, input_ids: List[int]):
20 return self.tokenizer.pad(
21 {"input_ids": input_ids},
22 padding=True,
23 pad_to_multiple_of=self.tokenizer.model_max_length,
24 return_tensors="pt"
25 ).input_ids
26
27 def get_embeddings(self, input_ids: torch.IntTensor):
28 prompts = input_ids.shape[0]
29 input_ids = input_ids.reshape((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
30 text_embeddings = self.text_encoder(input_ids)[0].reshape((prompts, -1, 768))
31 return text_embeddings
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index b68b028..3da0169 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -10,8 +10,9 @@ from diffusers.configuration_utils import FrozenDict
10from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel 10from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
11from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 11from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
12from diffusers.utils import logging 12from diffusers.utils import logging
13from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel 13from transformers import CLIPTextModel, CLIPTokenizer
14from schedulers.scheduling_euler_a import EulerAScheduler 14from schedulers.scheduling_euler_a import EulerAScheduler
15from models.clip.prompt import PromptProcessor
15 16
16logger = logging.get_logger(__name__) # pylint: disable=invalid-name 17logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17 18
@@ -24,22 +25,6 @@ def preprocess(image, w, h):
24 return 2.0 * image - 1.0 25 return 2.0 * image - 1.0
25 26
26 27
27def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None):
28 if isinstance(prompt, str):
29 prompt = [prompt] * batch_size
30
31 if isinstance(prompt, list) and isinstance(prompt[0], str):
32 prompt = [[p] for p in prompt]
33
34 if isinstance(prompt, list) and isinstance(prompt[0], list):
35 prompt_size = prompt_size or max([len(p) for p in prompt])
36 prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt]
37 else:
38 raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
39
40 return prompt_size, prompt
41
42
43class VlpnStableDiffusion(DiffusionPipeline): 28class VlpnStableDiffusion(DiffusionPipeline):
44 def __init__( 29 def __init__(
45 self, 30 self,
@@ -66,6 +51,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
66 new_config["steps_offset"] = 1 51 new_config["steps_offset"] = 1
67 scheduler._internal_dict = FrozenDict(new_config) 52 scheduler._internal_dict = FrozenDict(new_config)
68 53
54 self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
55
69 self.register_modules( 56 self.register_modules(
70 vae=vae, 57 vae=vae,
71 text_encoder=text_encoder, 58 text_encoder=text_encoder,
@@ -101,34 +88,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
101 # set slice_size = `None` to disable `attention slicing` 88 # set slice_size = `None` to disable `attention slicing`
102 self.enable_attention_slicing(None) 89 self.enable_attention_slicing(None)
103 90
104 def embeddings_for_prompt(self, prompt: List[List[str]]):
105 text_embeddings = []
106
107 for p in prompt:
108 inputs = self.tokenizer(
109 p,
110 padding="max_length",
111 max_length=self.tokenizer.model_max_length,
112 return_tensors="pt",
113 )
114 input_ids = inputs.input_ids
115
116 if input_ids.shape[-1] > self.tokenizer.model_max_length:
117 removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:])
118 logger.warning(
119 "The following part of your input was truncated because CLIP can only handle sequences up to"
120 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
121 )
122 print(f"Too many tokens: {removed_text}")
123 input_ids = input_ids[:, : self.tokenizer.model_max_length]
124
125 embeddings = self.text_encoder(input_ids.to(self.device))[0]
126 embeddings = embeddings.reshape((1, -1, 768))
127 text_embeddings.append(embeddings)
128
129 text_embeddings = torch.cat(text_embeddings)
130 return text_embeddings
131
132 @torch.no_grad() 91 @torch.no_grad()
133 def __call__( 92 def __call__(
134 self, 93 self,
@@ -195,13 +154,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
195 (nsfw) content, according to the `safety_checker`. 154 (nsfw) content, according to the `safety_checker`.
196 """ 155 """
197 156
198 prompt_size, prompt = normalize_prompt(prompt) 157 if isinstance(prompt, str):
158 prompt = [prompt]
159
199 batch_size = len(prompt) 160 batch_size = len(prompt)
200 _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size)
201 161
202 if len(negative_prompt) != batch_size: 162 if isinstance(negative_prompt, str):
163 negative_prompt = [negative_prompt] * batch_size
164
165 if len(negative_prompt) != len(prompt):
203 raise ValueError( 166 raise ValueError(
204 f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}") 167 f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
205 168
206 if height % 8 != 0 or width % 8 != 0: 169 if height % 8 != 0 or width % 8 != 0:
207 raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 170 raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -213,7 +176,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
213 self.scheduler.set_timesteps(num_inference_steps) 176 self.scheduler.set_timesteps(num_inference_steps)
214 177
215 # get prompt text embeddings 178 # get prompt text embeddings
216 text_embeddings = self.embeddings_for_prompt(prompt) 179 text_input_ids = self.prompt_processor.get_input_ids(prompt)
217 180
218 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 181 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
219 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 182 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -221,12 +184,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
221 do_classifier_free_guidance = guidance_scale > 1.0 184 do_classifier_free_guidance = guidance_scale > 1.0
222 # get unconditional embeddings for classifier free guidance 185 # get unconditional embeddings for classifier free guidance
223 if do_classifier_free_guidance: 186 if do_classifier_free_guidance:
224 uncond_embeddings = self.embeddings_for_prompt(negative_prompt) 187 unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
188 text_input_ids = unconditional_input_ids + text_input_ids
225 189
226 # For classifier free guidance, we need to do two forward passes. 190 text_input_ids = self.prompt_processor.unify_input_ids(text_input_ids)
227 # Here we concatenate the unconditional and text embeddings into a single batch 191 text_embeddings = self.prompt_processor.get_embeddings(text_input_ids)
228 # to avoid doing two forward passes
229 text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
230 192
231 offset = self.scheduler.config.get("steps_offset", 0) 193 offset = self.scheduler.config.get("steps_offset", 0)
232 init_timestep = num_inference_steps + offset 194 init_timestep = num_inference_steps + offset