 data/csv.py                                          | 14
 dreambooth_plus.py                                   | 44
 infer.py                                             | 12
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 94
 4 files changed, 107 insertions(+), 57 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index aad970c..316c099 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -72,8 +72,8 @@ class CSVDataModule(pl.LightningDataModule):
         ]
 
     def prepare_data(self):
-        metadata = pd.read_csv(self.data_file)
-        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != "x"]
+        metadata = pd.read_json(self.data_file)
+        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True]
         num_images = len(metadata)
 
         valid_set_size = int(num_images * 0.2)
@@ -163,6 +163,12 @@ class CSVDataset(Dataset):
 
         example = {}
 
+        if isinstance(item.prompt, str):
+            item.prompt = [item.prompt]
+
+        if isinstance(item.nprompt, str):
+            item.nprompt = [item.nprompt]
+
         example["prompts"] = item.prompt
         example["nprompts"] = item.nprompt
 
@@ -177,7 +183,7 @@ class CSVDataset(Dataset):
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
             item.prompt.format(self.instance_identifier),
-            padding="do_not_pad",
+            padding="max_length",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
         ).input_ids
@@ -190,7 +196,7 @@ class CSVDataset(Dataset):
         example["class_images"] = class_image
         example["class_prompt_ids"] = self.tokenizer(
             item.prompt.format(self.class_identifier),
-            padding="do_not_pad",
+            padding="max_length",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
         ).input_ids
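
Note: prepare_data now loads the metadata with pd.read_json instead of pd.read_csv, and the skip flag is compared against the boolean True rather than the string "x". A hypothetical metadata file under the new format (only the prompt, nprompt and skip field names come from the code above; the image field name and all values are made up for illustration):

    [
        {"image": "instance/01.jpg", "prompt": "a photo of {}", "nprompt": "blurry, deformed", "skip": false},
        {"image": "instance/02.jpg", "prompt": ["a photo of {}", "studio lighting"], "nprompt": "", "skip": true}
    ]

prompt and nprompt may now each be a single string or a list of sub-prompts; CSVDataset wraps bare strings into one-element lists before filling the example dict.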
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index a98417f..ae31377 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -124,7 +124,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1500,
+        default=1400,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -147,7 +147,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -469,9 +469,16 @@ class Checkpointer:
 
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(self.instance_identifier)
-                      for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-            nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+            prompt = [
+                [p.format(self.instance_identifier) for p in prompt]
+                for batch in batches
+                for prompt in batch["prompts"]
+            ][:self.sample_batch_size]
+            nprompt = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompts"]
+            ][:self.sample_batch_size]
 
             samples = pipeline(
                 prompt=prompt,
@@ -666,6 +673,17 @@ def main():
         }
         return batch
 
+    def encode_input_ids(input_ids):
+        text_embeddings = []
+
+        for ids in input_ids:
+            embeddings = text_encoder(ids)[0]
+            embeddings = embeddings.reshape((1, -1, 768))
+            text_embeddings.append(embeddings)
+
+        text_embeddings = torch.cat(text_embeddings)
+        return text_embeddings
+
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
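
Note: since every prompt is now a list of sub-prompts, each entry of batch["input_ids"] carries one row of token ids per sub-prompt, and encode_input_ids flattens each example's sub-prompt embeddings into a single, longer conditioning sequence. A rough shape sketch, assuming the usual CLIP text encoder dimensions (the 768 hidden size appears in the code above; the 77-token context length is an assumption):

    # ids for one example: (n_subprompts, 77) token ids
    # text_encoder(ids)[0]      -> (n_subprompts, 77, 768)
    # .reshape((1, -1, 768))    -> (1, n_subprompts * 77, 768)
    # torch.cat over the batch  -> (batch_size, n_subprompts * 77, 768)
    # torch.cat only works if every example in the batch has the same number
    # of sub-prompts, i.e. the same flattened sequence length.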
@@ -688,8 +706,10 @@ def main():
     missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
     if len(missing_data) != 0:
-        batched_data = [missing_data[i:i+args.sample_batch_size]
-                        for i in range(0, len(missing_data), args.sample_batch_size)]
+        batched_data = [
+            missing_data[i:i+args.sample_batch_size]
+            for i in range(0, len(missing_data), args.sample_batch_size)
+        ]
 
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
@@ -706,9 +726,9 @@ def main():
 
         with torch.inference_mode():
             for batch in batched_data:
-                image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
-                nprompt = [p.nprompt for p in batch]
+                image_name = [item.class_image_path for item in batch]
+                prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch]
+                nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
                     prompt=prompt,
@@ -855,7 +875,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -954,7 +974,7 @@ def main():
 
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])
 
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -19,6 +19,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 torch.backends.cuda.matmul.allow_tf32 = True
 
 
+line_sep = " <OR> "
+
+
 default_args = {
     "model": None,
     "scheduler": "euler_a",
@@ -95,10 +98,12 @@ def create_cmd_parser():
     parser.add_argument(
         "--prompt",
         type=str,
+        nargs="+",
     )
     parser.add_argument(
         "--negative_prompt",
         type=str,
+        nargs="*",
     )
     parser.add_argument(
         "--image",
@@ -271,9 +276,14 @@ def generate(output_dir, pipeline, args):
             dynamic_ncols=True
         )
 
+        if isinstance(args.prompt, str):
+            args.prompt = [args.prompt]
+
+        prompt = [p.split(line_sep) for p in args.prompt] * args.batch_size
+
         generator = torch.Generator(device="cuda").manual_seed(seed + i)
         images = pipeline(
-            prompt=[args.prompt] * args.batch_size,
+            prompt=prompt,
             height=args.height,
             width=args.width,
             negative_prompt=args.negative_prompt,
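
Note: with nargs="+" each --prompt argument becomes its own prompt, and the new line_sep splits every prompt into sub-prompts before it reaches the pipeline. A quick sketch of that transformation (the prompt text and batch size are only illustrative):

    line_sep = " <OR> "
    args_prompt = ["a photo of a corgi <OR> detailed fur, studio lighting"]  # e.g. from --prompt
    prompt = [p.split(line_sep) for p in args_prompt] * 2                    # batch_size = 2
    # prompt == [["a photo of a corgi", "detailed fur, studio lighting"],
    #            ["a photo of a corgi", "detailed fur, studio lighting"]]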
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 8b08a6f..b68b028 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -24,6 +24,22 @@ def preprocess(image, w, h):
     return 2.0 * image - 1.0
 
 
+def normalize_prompt(prompt: Union[str, List[str], List[List[str]]], batch_size: int = 1, prompt_size: int = None):
+    if isinstance(prompt, str):
+        prompt = [prompt] * batch_size
+
+    if isinstance(prompt, list) and isinstance(prompt[0], str):
+        prompt = [[p] for p in prompt]
+
+    if isinstance(prompt, list) and isinstance(prompt[0], list):
+        prompt_size = prompt_size or max([len(p) for p in prompt])
+        prompt: List[List[str]] = [subprompt + [""] * (prompt_size - len(subprompt)) for subprompt in prompt]
+    else:
+        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+    return prompt_size, prompt
+
+
 class VlpnStableDiffusion(DiffusionPipeline):
     def __init__(
         self,
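
Note: normalize_prompt brings every accepted prompt shape to List[List[str]] and pads each entry with empty strings so all entries share the same number of sub-prompts. A worked example of the three accepted inputs (prompt text is illustrative):

    normalize_prompt("a cat", batch_size=2)
    # -> (1, [["a cat"], ["a cat"]])

    normalize_prompt(["a cat", "a dog"])
    # -> (1, [["a cat"], ["a dog"]])

    normalize_prompt([["a cat", "fluffy"], ["a dog"]])
    # -> (2, [["a cat", "fluffy"], ["a dog", ""]])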
@@ -85,11 +101,39 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def embeddings_for_prompt(self, prompt: List[List[str]]):
+        text_embeddings = []
+
+        for p in prompt:
+            inputs = self.tokenizer(
+                p,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                return_tensors="pt",
+            )
+            input_ids = inputs.input_ids
+
+            if input_ids.shape[-1] > self.tokenizer.model_max_length:
+                removed_text = self.tokenizer.batch_decode(input_ids[:, self.tokenizer.model_max_length:])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+                print(f"Too many tokens: {removed_text}")
+                input_ids = input_ids[:, : self.tokenizer.model_max_length]
+
+            embeddings = self.text_encoder(input_ids.to(self.device))[0]
+            embeddings = embeddings.reshape((1, -1, 768))
+            text_embeddings.append(embeddings)
+
+        text_embeddings = torch.cat(text_embeddings)
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
-        negative_prompt: Optional[Union[str, List[str]]] = None,
+        prompt: Union[str, List[str], List[List[str]]],
+        negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
         strength: float = 0.8,
         height: Optional[int] = 512,
         width: Optional[int] = 512,
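
Note: because normalize_prompt pads every prompt (and every negative prompt) to the same number of sub-prompts, the conditional and unconditional embeddings produced by embeddings_for_prompt share one sequence length, which keeps the classifier-free-guidance concatenation later in __call__ valid. A sketch of that invariant (values are illustrative):

    # prompt_size, prompt = normalize_prompt([["a cat", "fluffy"], ["a dog"]])
    # _, negative_prompt  = normalize_prompt("", batch_size=2, prompt_size=prompt_size)
    # embeddings_for_prompt(prompt).shape == embeddings_for_prompt(negative_prompt).shape
    # so torch.cat([uncond_embeddings, text_embeddings]) stays well-formed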
@@ -151,23 +195,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 (nsfw) content, according to the `safety_checker`.
         """
 
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        prompt_size, prompt = normalize_prompt(prompt)
+        batch_size = len(prompt)
+        _, negative_prompt = normalize_prompt(negative_prompt or "", batch_size, prompt_size)
 
-        if negative_prompt is None:
-            negative_prompt = [""] * batch_size
-        elif isinstance(negative_prompt, str):
-            negative_prompt = [negative_prompt] * batch_size
-        elif isinstance(negative_prompt, list):
-            if len(negative_prompt) != batch_size:
-                raise ValueError(
-                    f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}")
-        else:
-            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+        if len(negative_prompt) != batch_size:
+            raise ValueError(
+                f"`prompt` and `negative_prompt` have to be the same length, but are {batch_size} and {len(negative_prompt)}")
 
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -179,23 +213,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         self.scheduler.set_timesteps(num_inference_steps)
         # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length:])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            print(f"Too many tokens: {removed_text}")
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.embeddings_for_prompt(prompt)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -203,11 +221,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            max_length = text_input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
-            )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.embeddings_for_prompt(negative_prompt)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
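
Note: after these changes __call__ accepts a plain string, a list of strings, or a list of sub-prompt lists, and the negative prompt is broadcast and padded to the same shape before both pass through embeddings_for_prompt. A hedged usage sketch (assumes `pipeline` is an already constructed VlpnStableDiffusion; prompt text and guidance settings are placeholders, not taken from this diff):

    output = pipeline(
        prompt=[
            ["a photo of a corgi", "detailed fur, studio lighting"],
            ["a photo of a corgi at the beach"],   # normalize_prompt pads this to two sub-prompts
        ],
        negative_prompt="blurry, low quality",     # broadcast and padded to match the prompts
        height=512,
        width=512,
        guidance_scale=7.5,
    )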
