diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 14 | ||||
-rw-r--r-- | data/keywords.py | 13 |
2 files changed, 21 insertions, 6 deletions
diff --git a/data/csv.py b/data/csv.py index 3af9925..c5e7aef 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -1,12 +1,13 @@ | |||
1 | import math | 1 | import math |
2 | import torch | ||
3 | import json | 2 | import json |
4 | from functools import partial | 3 | from functools import partial |
5 | from pathlib import Path | 4 | from pathlib import Path |
6 | from typing import NamedTuple, Optional, Union, Callable | 5 | from typing import NamedTuple, Optional, Union, Callable |
7 | 6 | ||
8 | from PIL import Image | 7 | from PIL import Image |
8 | import numpy as np | ||
9 | 9 | ||
10 | import torch | ||
10 | from torch.utils.data import IterableDataset, DataLoader, random_split | 11 | from torch.utils.data import IterableDataset, DataLoader, random_split |
11 | from torchvision import transforms | 12 | from torchvision import transforms |
12 | from transformers import CLIPTokenizer | 13 | from transformers import CLIPTokenizer |
@@ -141,8 +142,8 @@ class VlpnDataItem(NamedTuple): | |||
141 | nprompt: str | 142 | nprompt: str |
142 | collection: list[str] | 143 | collection: list[str] |
143 | 144 | ||
144 | def full_prompt(self, dropout: float = 0, shuffle: bool = False): | 145 | def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None): |
145 | return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle) | 146 | return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator) |
146 | 147 | ||
147 | 148 | ||
148 | def keyword_filter( | 149 | def keyword_filter( |
@@ -193,6 +194,7 @@ class VlpnDataModule(): | |||
193 | train_set_pad: Optional[int] = None, | 194 | train_set_pad: Optional[int] = None, |
194 | valid_set_pad: Optional[int] = None, | 195 | valid_set_pad: Optional[int] = None, |
195 | generator: Optional[torch.Generator] = None, | 196 | generator: Optional[torch.Generator] = None, |
197 | npgenerator: Optional[np.random.Generator] = None, | ||
196 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 198 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
197 | dtype: torch.dtype = torch.float32, | 199 | dtype: torch.dtype = torch.float32, |
198 | ): | 200 | ): |
@@ -228,6 +230,7 @@ class VlpnDataModule(): | |||
228 | self.batch_size = batch_size | 230 | self.batch_size = batch_size |
229 | self.dtype = dtype | 231 | self.dtype = dtype |
230 | self.generator = generator | 232 | self.generator = generator |
233 | self.npgenerator = npgenerator or np.random.default_rng() | ||
231 | 234 | ||
232 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: | 235 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: |
233 | tpl_image = template["image"] if "image" in template else "{}" | 236 | tpl_image = template["image"] if "image" in template else "{}" |
@@ -297,6 +300,7 @@ class VlpnDataModule(): | |||
297 | 300 | ||
298 | items = self.prepare_items(template, expansions, items) | 301 | items = self.prepare_items(template, expansions, items) |
299 | items = self.filter_items(items) | 302 | items = self.filter_items(items) |
303 | self.npgenerator.shuffle(items) | ||
300 | 304 | ||
301 | num_images = len(items) | 305 | num_images = len(items) |
302 | 306 | ||
@@ -370,6 +374,7 @@ class VlpnDataset(IterableDataset): | |||
370 | interpolation: str = "bicubic", | 374 | interpolation: str = "bicubic", |
371 | color_jitter: bool = True, | 375 | color_jitter: bool = True, |
372 | generator: Optional[torch.Generator] = None, | 376 | generator: Optional[torch.Generator] = None, |
377 | npgenerator: Optional[np.random.Generator] = None, | ||
373 | ): | 378 | ): |
374 | self.items = items | 379 | self.items = items |
375 | self.batch_size = batch_size | 380 | self.batch_size = batch_size |
@@ -383,6 +388,7 @@ class VlpnDataset(IterableDataset): | |||
383 | self.interpolation = interpolations[interpolation] | 388 | self.interpolation = interpolations[interpolation] |
384 | self.color_jitter = color_jitter | 389 | self.color_jitter = color_jitter |
385 | self.generator = generator | 390 | self.generator = generator |
391 | self.npgenerator = npgenerator | ||
386 | 392 | ||
387 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 393 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
388 | [item.instance_image_path for item in self.items], | 394 | [item.instance_image_path for item in self.items], |
@@ -477,7 +483,7 @@ class VlpnDataset(IterableDataset): | |||
477 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) | 483 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) |
478 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 484 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
479 | 485 | ||
480 | example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True)) | 486 | example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator)) |
481 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) | 487 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) |
482 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 488 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) |
483 | 489 | ||
diff --git a/data/keywords.py b/data/keywords.py index 629006d..8632d67 100644 --- a/data/keywords.py +++ b/data/keywords.py | |||
@@ -1,14 +1,23 @@ | |||
1 | from typing import Optional | ||
2 | |||
1 | import numpy as np | 3 | import numpy as np |
2 | 4 | ||
3 | 5 | ||
4 | def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str: | 6 | def keywords_to_str( |
7 | keywords: list[str], | ||
8 | undroppable_keywords: list[str] = [], | ||
9 | dropout: float = 0, | ||
10 | shuffle: bool = False, | ||
11 | npgenerator: Optional[np.random.Generator] = None | ||
12 | ) -> str: | ||
5 | if dropout != 0: | 13 | if dropout != 0: |
6 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] | 14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
7 | else: | 15 | else: |
8 | keywords = keywords.copy() | 16 | keywords = keywords.copy() |
9 | keywords += undroppable_keywords | 17 | keywords += undroppable_keywords |
10 | if shuffle: | 18 | if shuffle: |
11 | np.random.shuffle(keywords) | 19 | npgenerator = npgenerator or np.random.default_rng() |
20 | npgenerator.shuffle(keywords) | ||
12 | return ", ".join(keywords) | 21 | return ", ".join(keywords) |
13 | 22 | ||
14 | 23 | ||