 data/csv.py                                         | 78
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 40
 train_dreambooth.py                                 | 14
 train_ti.py                                         | 24
 training/util.py                                    | 14
 5 files changed, 102 insertions(+), 68 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 9be36ba..289a64d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,8 +41,8 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]):
 
 
 def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool):
-    item_order: list[int] = []
-    item_buckets: list[int] = []
+    bucket_items: list[int] = []
+    bucket_assignments: list[int] = []
     buckets = [1.0]
 
     for i in range(1, num_buckets + 1):
@@ -70,10 +70,10 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_
         if len(indices.shape) == 0:
             indices = indices.unsqueeze(0)
 
-        item_order += [i] * len(indices)
-        item_buckets += indices
+        bucket_items += [i] * len(indices)
+        bucket_assignments += indices
 
-    return buckets.tolist(), item_order, item_buckets
+    return buckets.tolist(), bucket_items, bucket_assignments
 
 
 class VlpnDataItem(NamedTuple):
@@ -94,8 +94,8 @@ class VlpnDataModule():
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
-        num_aspect_ratio_buckets: int = 0,
-        progressive_aspect_ratio_buckets: bool = False,
+        num_buckets: int = 0,
+        progressive_buckets: bool = False,
         dropout: float = 0,
         interpolation: str = "bicubic",
         template_key: str = "template",
@@ -119,8 +119,8 @@ class VlpnDataModule():
 
         self.prompt_processor = prompt_processor
         self.size = size
-        self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
-        self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
+        self.num_buckets = num_buckets
+        self.progressive_buckets = progressive_buckets
         self.dropout = dropout
         self.template_key = template_key
         self.interpolation = interpolation
@@ -207,15 +207,15 @@ class VlpnDataModule():
 
         train_dataset = VlpnDataset(
             self.data_train, self.prompt_processor,
-            num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets,
-            batch_size=self.batch_size,
+            num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
+            batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
         )
 
         val_dataset = VlpnDataset(
             self.data_val, self.prompt_processor,
-            batch_size=self.batch_size,
+            batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
         )
 
@@ -256,7 +256,7 @@ class VlpnDataset(IterableDataset):
         self.interpolation = interpolations[interpolation]
         self.generator = generator
 
-        buckets, item_order, item_buckets = generate_buckets(
+        buckets, bucket_items, bucket_assignments = generate_buckets(
             [item.instance_image_path for item in items],
             size,
             num_buckets,
@@ -264,23 +264,27 @@ class VlpnDataset(IterableDataset):
         )
 
         self.buckets = torch.tensor(buckets)
-        self.item_order = torch.tensor(item_order)
-        self.item_buckets = torch.tensor(item_buckets)
+        self.bucket_items = torch.tensor(bucket_items)
+        self.bucket_assignments = torch.tensor(bucket_assignments)
+        self.bucket_item_range = torch.arange(len(bucket_items))
+
+        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
 
     def __len__(self):
-        return len(self.item_buckets)
+        return self.length_
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
 
         if self.shuffle:
-            perm = torch.randperm(len(self.item_buckets), generator=self.generator)
-            self.item_order = self.item_order[perm]
-            self.item_buckets = self.item_buckets[perm]
+            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
+            self.bucket_items = self.bucket_items[perm]
+            self.bucket_assignments = self.bucket_assignments[perm]
 
-        item_mask = torch.ones_like(self.item_buckets, dtype=bool)
-        bucket = -1
         image_transforms = None
+
+        mask = torch.ones_like(self.bucket_assignments, dtype=bool)
+        bucket = -1
         batch = []
         batch_size = self.batch_size
 
@@ -289,25 +293,30 @@ class VlpnDataset(IterableDataset):
             worker_batch = math.ceil(len(self) / worker_info.num_workers)
             start = worker_info.id * worker_batch
             end = start + worker_batch
-            item_mask[:start] = False
-            item_mask[end:] = False
+            mask[:start] = False
+            mask[end:] = False
 
-        while item_mask.any():
-            item_indices = self.item_order[(self.item_buckets == bucket) & item_mask]
+        while mask.any():
+            bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
+            bucket_items = self.bucket_items[bucket_mask]
 
-            if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0):
+            if len(batch) >= batch_size:
                 yield batch
                 batch = []
 
-            if len(item_indices) == 0:
-                bucket = self.item_buckets[item_mask][0]
+            if len(bucket_items) == 0:
+                if len(batch) != 0:
+                    yield batch
+                    batch = []
+
+                bucket = self.bucket_assignments[mask][0]
                 ratio = self.buckets[bucket]
                 width = self.size * ratio if ratio > 1 else self.size
                 height = self.size / ratio if ratio < 1 else self.size
 
                 image_transforms = transforms.Compose(
                     [
-                        transforms.Resize(min(width, height), interpolation=self.interpolation),
+                        transforms.Resize(self.size, interpolation=self.interpolation),
                         transforms.RandomCrop((height, width)),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
@@ -315,15 +324,14 @@ class VlpnDataset(IterableDataset):
                     ]
                 )
             else:
-                item_index = item_indices[0]
+                item_index = bucket_items[0]
                 item = self.items[item_index]
-                item_mask[item_index] = False
+                mask[self.bucket_item_range[bucket_mask][0]] = False
 
                 example = {}
 
-                example["prompts"] = keywords_to_prompt(item.prompt)
-                example["cprompts"] = item.cprompt
-                example["nprompts"] = item.nprompt
+                example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
+                example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
 
                 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
                 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
@@ -332,7 +340,7 @@ class VlpnDataset(IterableDataset):
 
                 if self.num_class_images != 0:
                     example["class_images"] = image_transforms(get_image(item.class_image_path))
-                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
+                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
 
                 batch.append(example)
 
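Note: the new length_ counts batches per bucket (batches never mix aspect-ratio buckets) rather than raw items. A minimal, runnable sketch of that computation, using made-up assignments rather than data from this repo:

    import torch

    # Toy data: 7 items spread over 3 aspect-ratio buckets, batch size 2.
    bucket_assignments = torch.tensor([0, 0, 0, 1, 1, 2, 2])
    batch_size = 2

    # Items per bucket -> [3, 2, 2]; batches per bucket (ceil) -> [2, 1, 1]; total 4.
    length = (bucket_assignments.bincount() / batch_size).ceil().long().sum().item()
    assert length == 4
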
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 53b5eea..cb300d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -79,6 +79,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             unet=unet,
             scheduler=scheduler,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -160,14 +161,22 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 return torch.device(module._hf_hook.execution_device)
         return self.device
 
-    def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps):
-        if isinstance(prompt, str):
+    def check_inputs(
+        self,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
+        width: Optional[int],
+        height: Optional[int],
+        strength: float,
+        callback_steps: Optional[int]
+    ):
+        if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)):
             prompt = [prompt]
 
         if negative_prompt is None:
             negative_prompt = ""
 
-        if isinstance(negative_prompt, str):
+        if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)):
             negative_prompt = [negative_prompt] * len(prompt)
 
         if not isinstance(prompt, list):
@@ -196,12 +205,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return prompt, negative_prompt
 
-    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
-        text_input_ids = self.prompt_processor.get_input_ids(prompt)
+    def encode_prompt(
+        self,
+        prompt: Union[List[str], List[List[int]]],
+        negative_prompt: Union[List[str], List[List[int]]],
+        num_images_per_prompt: int,
+        do_classifier_free_guidance: bool,
+        device
+    ):
+        text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt
         text_input_ids *= num_images_per_prompt
 
         if do_classifier_free_guidance:
-            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            unconditional_input_ids = self.prompt_processor.get_input_ids(
+                negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt
             unconditional_input_ids *= num_images_per_prompt
             text_input_ids = unconditional_input_ids + text_input_ids
 
@@ -314,12 +331,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str], List[List[str]]],
-        negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         strength: float = 0.8,
-        height: Optional[int] = 768,
-        width: Optional[int] = 768,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
@@ -379,6 +396,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         # 1. Check inputs. Raise error if not correct
         prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps)
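Note: the hard-coded 768 default is replaced by a value derived from the UNet and VAE configs. A rough sketch of the arithmetic, using typical Stable Diffusion 2.x values that are assumed here rather than read from the diff:

    # Assumed configs: the VAE lists 4 block_out_channels, the UNet reports sample_size = 96.
    vae_scale_factor = 2 ** (4 - 1)        # 8, the VAE's spatial downsampling factor
    default_edge = 96 * vae_scale_factor   # 768 pixels, used when height/width are None

Callers that still pass explicit height/width are unaffected, since "height or ..." keeps a given value.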
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 42a7d0f..79eede6 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -699,9 +699,9 @@ def main():
         return cond3 and cond4
 
     def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        cprompts = [example["cprompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
+        prompt_ids = [example["prompt_ids"] for example in examples]
+        nprompt_ids = [example["nprompt_ids"] for example in examples]
+
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -713,16 +713,18 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
+        prompts = prompt_processor.unify_input_ids(prompt_ids)
+        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
         inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
-            "prompts": prompts,
-            "cprompts": cprompts,
-            "nprompts": nprompts,
+            "prompt_ids": prompts.input_ids,
+            "nprompt_ids": nprompts.input_ids,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
             "attention_mask": inputs.attention_mask,
         }
+
         return batch
 
     datamodule = VlpnDataModule(
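Note: the dataset now emits token ids instead of prompt strings, so collate_fn pads them into tensors via unify_input_ids. A hedged sketch of what that padding step is assumed to do (the helper itself is defined elsewhere in this repo; token ids and pad value below are made up):

    # Two ragged prompt-id lists from a batch.
    prompt_ids = [[49406, 320, 1125, 49407], [49406, 320, 49407]]

    # Roughly tokenizer.pad behavior: right-pad to the longest sequence.
    max_len = max(len(ids) for ids in prompt_ids)
    input_ids = [ids + [49407] * (max_len - len(ids)) for ids in prompt_ids]
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in prompt_ids]

The batch then carries prompt_ids/nprompt_ids tensors that the checkpointer (training/util.py below) can hand straight to the pipeline.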
diff --git a/train_ti.py b/train_ti.py
index 727b591..323ef10 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -140,13 +140,13 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--num_aspect_ratio_buckets",
+        "--num_buckets",
         type=int,
         default=4,
-        help="Number of buckets in either direction (adds 64 pixels per step).",
+        help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).",
     )
     parser.add_argument(
-        "--progressive_aspect_ratio_buckets",
+        "--progressive_buckets",
         action="store_true",
         help="Include images in smaller buckets as well.",
     )
@@ -681,9 +681,9 @@ def main():
         return cond1 and cond3 and cond4
 
     def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        cprompts = [example["cprompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
+        prompt_ids = [example["prompt_ids"] for example in examples]
+        nprompt_ids = [example["nprompt_ids"] for example in examples]
+
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -695,16 +695,18 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
+        prompts = prompt_processor.unify_input_ids(prompt_ids)
+        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
         inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
-            "prompts": prompts,
-            "cprompts": cprompts,
-            "nprompts": nprompts,
+            "prompt_ids": prompts.input_ids,
+            "nprompt_ids": nprompts.input_ids,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
             "attention_mask": inputs.attention_mask,
         }
+
         return batch
 
     datamodule = VlpnDataModule(
@@ -714,8 +716,8 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
-        progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
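Note: with the argparse rename, a run opting into bucketing is launched with the shorter flags, for example (only the two renamed flags are shown; all other arguments elided):

    python train_ti.py --num_buckets 4 --progressive_buckets ...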
diff --git a/training/util.py b/training/util.py
index ae6bfc4..60d64f0 100644
--- a/training/util.py
+++ b/training/util.py
@@ -73,20 +73,22 @@ class CheckpointerBase:
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
         batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
-        prompts = [
+        prompt_ids = [
             prompt
             for batch in batches
-            for prompt in batch["prompts"]
+            for prompt in batch["prompt_ids"]
         ]
-        nprompts = [
+        nprompt_ids = [
             prompt
             for batch in batches
-            for prompt in batch["nprompts"]
+            for prompt in batch["nprompt_ids"]
         ]
 
         for i in range(self.sample_batches):
-            prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-            nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+            start = i * self.sample_batch_size
+            end = (i + 1) * self.sample_batch_size
+            prompt = prompt_ids[start:end]
+            nprompt = nprompt_ids[start:end]
 
             samples = pipeline(
                 prompt=prompt,
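Note: the slice refactor is behavior-preserving; the id lists it produces now go straight into pipeline(prompt=...), which the widened __call__ signature above accepts. A toy, runnable version of the slicing (token ids are made up):

    sample_batch_size = 2
    sample_batches = 2
    prompt_ids = [[49406, 320, 49407], [49406, 539, 49407],
                  [49406, 2368, 49407], [49406, 1125, 49407]]

    for i in range(sample_batches):
        start = i * sample_batch_size
        end = (i + 1) * sample_batch_size
        prompt = prompt_ids[start:end]  # passed on as pipeline(prompt=prompt, ...)
        print(prompt)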