From 1a0161f345191d78a19eec829f9d73b2c2c72f94 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 16 Apr 2023 09:44:12 +0200 Subject: Update --- data/csv.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index 3af9925..c5e7aef 100644 --- a/data/csv.py +++ b/data/csv.py @@ -1,12 +1,13 @@ import math -import torch import json from functools import partial from pathlib import Path from typing import NamedTuple, Optional, Union, Callable from PIL import Image +import numpy as np +import torch from torch.utils.data import IterableDataset, DataLoader, random_split from torchvision import transforms from transformers import CLIPTokenizer @@ -141,8 +142,8 @@ class VlpnDataItem(NamedTuple): nprompt: str collection: list[str] - def full_prompt(self, dropout: float = 0, shuffle: bool = False): - return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle) + def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None): + return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator) def keyword_filter( @@ -193,6 +194,7 @@ class VlpnDataModule(): train_set_pad: Optional[int] = None, valid_set_pad: Optional[int] = None, generator: Optional[torch.Generator] = None, + npgenerator: Optional[np.random.Generator] = None, filter: Optional[Callable[[VlpnDataItem], bool]] = None, dtype: torch.dtype = torch.float32, ): @@ -228,6 +230,7 @@ class VlpnDataModule(): self.batch_size = batch_size self.dtype = dtype self.generator = generator + self.npgenerator = npgenerator or np.random.default_rng() def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: tpl_image = template["image"] if "image" in template else "{}" @@ -297,6 +300,7 @@ class VlpnDataModule(): items = self.prepare_items(template, expansions, items) items = self.filter_items(items) + self.npgenerator.shuffle(items) num_images = len(items) @@ -370,6 +374,7 @@ class VlpnDataset(IterableDataset): interpolation: str = "bicubic", color_jitter: bool = True, generator: Optional[torch.Generator] = None, + npgenerator: Optional[np.random.Generator] = None, ): self.items = items self.batch_size = batch_size @@ -383,6 +388,7 @@ class VlpnDataset(IterableDataset): self.interpolation = interpolations[interpolation] self.color_jitter = color_jitter self.generator = generator + self.npgenerator = npgenerator self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( [item.instance_image_path for item in self.items], @@ -477,7 +483,7 @@ class VlpnDataset(IterableDataset): example["prompt_ids"] = self.get_input_ids(item.full_prompt()) example["nprompt_ids"] = self.get_input_ids(item.nprompt) - example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True)) + example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator)) example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) example["instance_images"] = image_transforms(get_image(item.instance_image_path)) -- cgit v1.2.3-54-g00ecf