summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-04 09:40:24 +0100
committerVolpeon <git@volpeon.ink>2023-01-04 09:40:24 +0100
commit403f525d0c6900cc6844c0d2f4ecb385fc131969 (patch)
tree385c62ef44cc33abc3c5d4b2084c376551137c5f /data/csv.py
parentDon't use vector_dropout by default (diff)
downloadtextual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.gz
textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.tar.bz2
textual-inversion-diff-403f525d0c6900cc6844c0d2f4ecb385fc131969.zip
Fixed reproducibility, more consistant validation
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index af36d9e..e901ab4 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -59,7 +59,7 @@ class CSVDataModule():
59 center_crop: bool = False, 59 center_crop: bool = False,
60 template_key: str = "template", 60 template_key: str = "template",
61 valid_set_size: Optional[int] = None, 61 valid_set_size: Optional[int] = None,
62 generator: Optional[torch.Generator] = None, 62 seed: Optional[int] = None,
63 filter: Optional[Callable[[CSVDataItem], bool]] = None, 63 filter: Optional[Callable[[CSVDataItem], bool]] = None,
64 collate_fn=None, 64 collate_fn=None,
65 num_workers: int = 0 65 num_workers: int = 0
@@ -84,7 +84,7 @@ class CSVDataModule():
84 self.template_key = template_key 84 self.template_key = template_key
85 self.interpolation = interpolation 85 self.interpolation = interpolation
86 self.valid_set_size = valid_set_size 86 self.valid_set_size = valid_set_size
87 self.generator = generator 87 self.seed = seed
88 self.filter = filter 88 self.filter = filter
89 self.collate_fn = collate_fn 89 self.collate_fn = collate_fn
90 self.num_workers = num_workers 90 self.num_workers = num_workers
@@ -155,7 +155,11 @@ class CSVDataModule():
155 valid_set_size = max(valid_set_size, 1) 155 valid_set_size = max(valid_set_size, 1)
156 train_set_size = num_images - valid_set_size 156 train_set_size = num_images - valid_set_size
157 157
158 data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) 158 generator = torch.Generator(device="cpu")
159 if self.seed is not None:
160 generator = generator.manual_seed(self.seed)
161
162 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
159 163
160 self.data_train = self.pad_items(data_train, self.num_class_images) 164 self.data_train = self.pad_items(data_train, self.num_class_images)
161 self.data_val = self.pad_items(data_val) 165 self.data_val = self.pad_items(data_val)