 data/csv.py                                          | 78
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 40
 train_dreambooth.py                                  | 14
 train_ti.py                                          | 24
 training/util.py                                     | 14
 5 files changed, 102 insertions(+), 68 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 9be36ba..289a64d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,8 +41,8 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]):
 
 
 def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool):
-    item_order: list[int] = []
-    item_buckets: list[int] = []
+    bucket_items: list[int] = []
+    bucket_assignments: list[int] = []
     buckets = [1.0]
 
     for i in range(1, num_buckets + 1):
@@ -70,10 +70,10 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_
         if len(indices.shape) == 0:
             indices = indices.unsqueeze(0)
 
-        item_order += [i] * len(indices)
-        item_buckets += indices
+        bucket_items += [i] * len(indices)
+        bucket_assignments += indices
 
-    return buckets.tolist(), item_order, item_buckets
+    return buckets.tolist(), bucket_items, bucket_assignments
 
 
 class VlpnDataItem(NamedTuple):
@@ -94,8 +94,8 @@ class VlpnDataModule():
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
-        num_aspect_ratio_buckets: int = 0,
-        progressive_aspect_ratio_buckets: bool = False,
+        num_buckets: int = 0,
+        progressive_buckets: bool = False,
         dropout: float = 0,
         interpolation: str = "bicubic",
         template_key: str = "template",
@@ -119,8 +119,8 @@ class VlpnDataModule():
 
         self.prompt_processor = prompt_processor
         self.size = size
-        self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
-        self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
+        self.num_buckets = num_buckets
+        self.progressive_buckets = progressive_buckets
         self.dropout = dropout
         self.template_key = template_key
         self.interpolation = interpolation
@@ -207,15 +207,15 @@ class VlpnDataModule():
 
         train_dataset = VlpnDataset(
             self.data_train, self.prompt_processor,
-            num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets,
-            batch_size=self.batch_size,
+            num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
+            batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
         )
 
         val_dataset = VlpnDataset(
             self.data_val, self.prompt_processor,
-            batch_size=self.batch_size,
+            batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
         )
 
@@ -256,7 +256,7 @@ class VlpnDataset(IterableDataset):
         self.interpolation = interpolations[interpolation]
         self.generator = generator
 
-        buckets, item_order, item_buckets = generate_buckets(
+        buckets, bucket_items, bucket_assignments = generate_buckets(
             [item.instance_image_path for item in items],
             size,
             num_buckets,
@@ -264,23 +264,27 @@ class VlpnDataset(IterableDataset):
         )
 
         self.buckets = torch.tensor(buckets)
-        self.item_order = torch.tensor(item_order)
-        self.item_buckets = torch.tensor(item_buckets)
+        self.bucket_items = torch.tensor(bucket_items)
+        self.bucket_assignments = torch.tensor(bucket_assignments)
+        self.bucket_item_range = torch.arange(len(bucket_items))
+
+        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
 
     def __len__(self):
-        return len(self.item_buckets)
+        return self.length_
 
     def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
 
        if self.shuffle:
-            perm = torch.randperm(len(self.item_buckets), generator=self.generator)
-            self.item_order = self.item_order[perm]
-            self.item_buckets = self.item_buckets[perm]
+            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
+            self.bucket_items = self.bucket_items[perm]
+            self.bucket_assignments = self.bucket_assignments[perm]
 
-        item_mask = torch.ones_like(self.item_buckets, dtype=bool)
-        bucket = -1
         image_transforms = None
+
+        mask = torch.ones_like(self.bucket_assignments, dtype=bool)
+        bucket = -1
         batch = []
         batch_size = self.batch_size
 
@@ -289,25 +293,30 @@ class VlpnDataset(IterableDataset):
             worker_batch = math.ceil(len(self) / worker_info.num_workers)
             start = worker_info.id * worker_batch
             end = start + worker_batch
-            item_mask[:start] = False
-            item_mask[end:] = False
+            mask[:start] = False
+            mask[end:] = False
 
-        while item_mask.any():
-            item_indices = self.item_order[(self.item_buckets == bucket) & item_mask]
+        while mask.any():
+            bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
+            bucket_items = self.bucket_items[bucket_mask]
 
-            if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0):
+            if len(batch) >= batch_size:
                 yield batch
                 batch = []
 
-            if len(item_indices) == 0:
-                bucket = self.item_buckets[item_mask][0]
+            if len(bucket_items) == 0:
+                if len(batch) != 0:
+                    yield batch
+                    batch = []
+
+                bucket = self.bucket_assignments[mask][0]
                 ratio = self.buckets[bucket]
                 width = self.size * ratio if ratio > 1 else self.size
                 height = self.size / ratio if ratio < 1 else self.size
 
                 image_transforms = transforms.Compose(
                     [
-                        transforms.Resize(min(width, height), interpolation=self.interpolation),
+                        transforms.Resize(self.size, interpolation=self.interpolation),
                         transforms.RandomCrop((height, width)),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
@@ -315,15 +324,14 @@ class VlpnDataset(IterableDataset):
                     ]
                 )
             else:
-                item_index = item_indices[0]
+                item_index = bucket_items[0]
                 item = self.items[item_index]
-                item_mask[item_index] = False
+                mask[self.bucket_item_range[bucket_mask][0]] = False
 
                 example = {}
 
-                example["prompts"] = keywords_to_prompt(item.prompt)
-                example["cprompts"] = item.cprompt
-                example["nprompts"] = item.nprompt
+                example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
+                example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
 
                 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
                 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
@@ -332,7 +340,7 @@ class VlpnDataset(IterableDataset):
 
                 if self.num_class_images != 0:
                     example["class_images"] = image_transforms(get_image(item.class_image_path))
-                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
+                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
 
                 batch.append(example)
 
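Note on the data/csv.py changes above: __len__ now counts whole batches per aspect-ratio bucket rather than items. A minimal sketch of that formula with made-up values (7 items across 3 buckets, batch_size = 4):

import torch

bucket_assignments = torch.tensor([0, 0, 1, 1, 1, 1, 2])  # hypothetical bucket per item
batch_size = 4

# bincount() -> tensor([2, 4, 1]); ceil(2/4) + ceil(4/4) + ceil(1/4) = 1 + 1 + 1
length = (bucket_assignments.bincount() / batch_size).ceil().long().sum().item()
print(length)  # 3 batches per epoch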
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 53b5eea..cb300d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -79,6 +79,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             unet=unet,
             scheduler=scheduler,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -160,14 +161,22 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 return torch.device(module._hf_hook.execution_device)
         return self.device
 
-    def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps):
-        if isinstance(prompt, str):
+    def check_inputs(
+        self,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
+        width: Optional[int],
+        height: Optional[int],
+        strength: float,
+        callback_steps: Optional[int]
+    ):
+        if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)):
             prompt = [prompt]
 
         if negative_prompt is None:
             negative_prompt = ""
 
-        if isinstance(negative_prompt, str):
+        if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)):
             negative_prompt = [negative_prompt] * len(prompt)
 
         if not isinstance(prompt, list):
@@ -196,12 +205,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return prompt, negative_prompt
 
-    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
-        text_input_ids = self.prompt_processor.get_input_ids(prompt)
+    def encode_prompt(
+        self,
+        prompt: Union[List[str], List[List[int]]],
+        negative_prompt: Union[List[str], List[List[int]]],
+        num_images_per_prompt: int,
+        do_classifier_free_guidance: bool,
+        device
+    ):
+        text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt
         text_input_ids *= num_images_per_prompt
 
         if do_classifier_free_guidance:
-            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            unconditional_input_ids = self.prompt_processor.get_input_ids(
+                negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt
             unconditional_input_ids *= num_images_per_prompt
             text_input_ids = unconditional_input_ids + text_input_ids
 
@@ -314,12 +331,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str], List[List[str]]],
-        negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         strength: float = 0.8,
-        height: Optional[int] = 768,
-        width: Optional[int] = 768,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
@@ -379,6 +396,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         # 1. Check inputs. Raise error if not correct
         prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps)
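The new vae_scale_factor and the height/width defaulting in __call__ above reduce to the arithmetic below. The concrete numbers are an assumption (a Stable Diffusion 2.x style setup with four VAE blocks and unet.config.sample_size == 96), not something stated in the diff:

# Hypothetical config values, for illustration only
block_out_channels = (128, 256, 512, 512)               # len() == 4
unet_sample_size = 96

vae_scale_factor = 2 ** (len(block_out_channels) - 1)   # 2**3 == 8
height = None or unet_sample_size * vae_scale_factor    # 768 when no height is given
width = None or unet_sample_size * vae_scale_factor     # 768 when no width is given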
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 42a7d0f..79eede6 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -699,9 +699,9 @@ def main():
         return cond3 and cond4
 
     def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        cprompts = [example["cprompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
+        prompt_ids = [example["prompt_ids"] for example in examples]
+        nprompt_ids = [example["nprompt_ids"] for example in examples]
+
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -713,16 +713,18 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
+        prompts = prompt_processor.unify_input_ids(prompt_ids)
+        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
         inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
-            "prompts": prompts,
-            "cprompts": cprompts,
-            "nprompts": nprompts,
+            "prompt_ids": prompts.input_ids,
+            "nprompt_ids": nprompts.input_ids,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
             "attention_mask": inputs.attention_mask,
         }
+
         return batch
 
     datamodule = VlpnDataModule(
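collate_fn now ships the sample prompts as padded token-id tensors via prompt_processor.unify_input_ids. That helper is outside this diff; the sketch below is a stand-in that assumes it pads ragged id lists to a common length, tokenizer-style:

import torch

def unify_input_ids_sketch(ids_list, pad_id=0):
    # Stand-in only: pad each example's token ids to the longest sequence in the batch.
    max_len = max(len(ids) for ids in ids_list)
    input_ids = torch.tensor([ids + [pad_id] * (max_len - len(ids)) for ids in ids_list])
    attention_mask = (input_ids != pad_id).long()
    return input_ids, attention_mask

prompt_ids = [[49406, 320, 2368, 49407], [49406, 320, 49407]]  # made-up CLIP-style ids
input_ids, attention_mask = unify_input_ids_sketch(prompt_ids)
print(input_ids.shape)    # torch.Size([2, 4])
print(attention_mask[1])  # tensor([1, 1, 1, 0])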
diff --git a/train_ti.py b/train_ti.py
index 727b591..323ef10 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -140,13 +140,13 @@ def parse_args():
        ),
    )
    parser.add_argument(
-        "--num_aspect_ratio_buckets",
+        "--num_buckets",
        type=int,
        default=4,
-        help="Number of buckets in either direction (adds 64 pixels per step).",
+        help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).",
    )
    parser.add_argument(
-        "--progressive_aspect_ratio_buckets",
+        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
@@ -681,9 +681,9 @@ def main():
         return cond1 and cond3 and cond4
 
     def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        cprompts = [example["cprompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
+        prompt_ids = [example["prompt_ids"] for example in examples]
+        nprompt_ids = [example["nprompt_ids"] for example in examples]
+
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -695,16 +695,18 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
+        prompts = prompt_processor.unify_input_ids(prompt_ids)
+        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
         inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
-            "prompts": prompts,
-            "cprompts": cprompts,
-            "nprompts": nprompts,
+            "prompt_ids": prompts.input_ids,
+            "nprompt_ids": nprompts.input_ids,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
             "attention_mask": inputs.attention_mask,
         }
+
         return batch
 
     datamodule = VlpnDataModule(
@@ -714,8 +716,8 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
-        progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
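The renamed --num_buckets flag keeps its meaning: buckets in either direction, 64 pixels per step. A rough sketch of the edge lengths and aspect ratios this implies for --resolution 768; the exact math lives in generate_buckets and is largely outside this diff:

size = 768        # --resolution
num_buckets = 4   # --num_buckets

# One bucket per 64-pixel step on the long edge, in both orientations,
# plus the square 1:1 bucket (illustrative, not the exact implementation).
edges = [size + i * 64 for i in range(1, num_buckets + 1)]   # [832, 896, 960, 1024]
ratios = [1.0] + [e / size for e in edges] + [size / e for e in edges]
print(sorted(round(r, 3) for r in ratios))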
diff --git a/training/util.py b/training/util.py
index ae6bfc4..60d64f0 100644
--- a/training/util.py
+++ b/training/util.py
@@ -73,20 +73,22 @@ class CheckpointerBase:
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
         batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
-        prompts = [
+        prompt_ids = [
             prompt
             for batch in batches
-            for prompt in batch["prompts"]
+            for prompt in batch["prompt_ids"]
         ]
-        nprompts = [
+        nprompt_ids = [
             prompt
             for batch in batches
-            for prompt in batch["nprompts"]
+            for prompt in batch["nprompt_ids"]
         ]
 
         for i in range(self.sample_batches):
-            prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-            nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+            start = i * self.sample_batch_size
+            end = (i + 1) * self.sample_batch_size
+            prompt = prompt_ids[start:end]
+            nprompt = nprompt_ids[start:end]
 
             samples = pipeline(
                 prompt=prompt,
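The sampling loop in training/util.py now slices the flattened token-id lists into sample batches and passes them straight to the pipeline, whose __call__ accepts List[List[int]] after the change above. With illustrative numbers:

sample_batches = 2
sample_batch_size = 3
prompt_ids = [[1], [2], [3], [4], [5], [6]]  # stands in for six tokenized prompts

for i in range(sample_batches):
    start = i * sample_batch_size
    end = (i + 1) * sample_batch_size
    print(prompt_ids[start:end])  # [[1], [2], [3]] then [[4], [5], [6]]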
