summary refs log tree commit diff stats
path: root/data/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/csv.py')
-rw-r--r-- data/csv.py 14
1 files changed, 10 insertions, 4 deletions
diff --git a/data/csv.py b/data/csv.py
index 3af9925..c5e7aef 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,12 +1,13 @@
1import math 1import math
2import torch
3import json 2import json
4from functools import partial 3from functools import partial
5from pathlib import Path 4from pathlib import Path
6from typing import NamedTuple, Optional, Union, Callable 5from typing import NamedTuple, Optional, Union, Callable
7 6
8from PIL import Image 7from PIL import Image
8import numpy as np
9 9
10import torch
10from torch.utils.data import IterableDataset, DataLoader, random_split 11from torch.utils.data import IterableDataset, DataLoader, random_split
11from torchvision import transforms 12from torchvision import transforms
12from transformers import CLIPTokenizer 13from transformers import CLIPTokenizer
@@ -141,8 +142,8 @@ class VlpnDataItem(NamedTuple):
141 nprompt: str 142 nprompt: str
142 collection: list[str] 143 collection: list[str]
143 144
144 def full_prompt(self, dropout: float = 0, shuffle: bool = False): 145 def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None):
145 return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle) 146 return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator)
146 147
147 148
148def keyword_filter( 149def keyword_filter(
@@ -193,6 +194,7 @@ class VlpnDataModule():
193 train_set_pad: Optional[int] = None, 194 train_set_pad: Optional[int] = None,
194 valid_set_pad: Optional[int] = None, 195 valid_set_pad: Optional[int] = None,
195 generator: Optional[torch.Generator] = None, 196 generator: Optional[torch.Generator] = None,
197 npgenerator: Optional[np.random.Generator] = None,
196 filter: Optional[Callable[[VlpnDataItem], bool]] = None, 198 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
197 dtype: torch.dtype = torch.float32, 199 dtype: torch.dtype = torch.float32,
198 ): 200 ):
@@ -228,6 +230,7 @@ class VlpnDataModule():
228 self.batch_size = batch_size 230 self.batch_size = batch_size
229 self.dtype = dtype 231 self.dtype = dtype
230 self.generator = generator 232 self.generator = generator
233 self.npgenerator = npgenerator or np.random.default_rng()
231 234
232 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: 235 def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
233 tpl_image = template["image"] if "image" in template else "{}" 236 tpl_image = template["image"] if "image" in template else "{}"
@@ -297,6 +300,7 @@ class VlpnDataModule():
297 300
298 items = self.prepare_items(template, expansions, items) 301 items = self.prepare_items(template, expansions, items)
299 items = self.filter_items(items) 302 items = self.filter_items(items)
303 self.npgenerator.shuffle(items)
300 304
301 num_images = len(items) 305 num_images = len(items)
302 306
@@ -370,6 +374,7 @@ class VlpnDataset(IterableDataset):
370 interpolation: str = "bicubic", 374 interpolation: str = "bicubic",
371 color_jitter: bool = True, 375 color_jitter: bool = True,
372 generator: Optional[torch.Generator] = None, 376 generator: Optional[torch.Generator] = None,
377 npgenerator: Optional[np.random.Generator] = None,
373 ): 378 ):
374 self.items = items 379 self.items = items
375 self.batch_size = batch_size 380 self.batch_size = batch_size
@@ -383,6 +388,7 @@ class VlpnDataset(IterableDataset):
383 self.interpolation = interpolations[interpolation] 388 self.interpolation = interpolations[interpolation]
384 self.color_jitter = color_jitter 389 self.color_jitter = color_jitter
385 self.generator = generator 390 self.generator = generator
391 self.npgenerator = npgenerator
386 392
387 self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( 393 self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
388 [item.instance_image_path for item in self.items], 394 [item.instance_image_path for item in self.items],
@@ -477,7 +483,7 @@ class VlpnDataset(IterableDataset):
477 example["prompt_ids"] = self.get_input_ids(item.full_prompt()) 483 example["prompt_ids"] = self.get_input_ids(item.full_prompt())
478 example["nprompt_ids"] = self.get_input_ids(item.nprompt) 484 example["nprompt_ids"] = self.get_input_ids(item.nprompt)
479 485
480 example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True)) 486 example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator))
481 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) 487 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
482 example["instance_images"] = image_transforms(get_image(item.instance_image_path)) 488 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
483 489