summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-10 09:22:02 +0100
committerVolpeon <git@volpeon.ink>2023-01-10 09:22:02 +0100
commit33e7d2ed37e32657ca94d92815043026c4cea7c0 (patch)
tree0af4d6ad0ba92a168e3ec17675147c76afe1baf0 /data/csv.py
parentEnable buckets for validation, fixed vaildation repeat arg (diff)
downloadtextual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.tar.gz
textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.tar.bz2
textual-inversion-diff-33e7d2ed37e32657ca94d92815043026c4cea7c0.zip
Added arg to disable tag shuffling
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py
index ed8e93d..9ad7dd6 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -122,6 +122,7 @@ class VlpnDataModule():
122 bucket_max_pixels: Optional[int] = None, 122 bucket_max_pixels: Optional[int] = None,
123 progressive_buckets: bool = False, 123 progressive_buckets: bool = False,
124 dropout: float = 0, 124 dropout: float = 0,
125 shuffle: bool = False,
125 interpolation: str = "bicubic", 126 interpolation: str = "bicubic",
126 template_key: str = "template", 127 template_key: str = "template",
127 valid_set_size: Optional[int] = None, 128 valid_set_size: Optional[int] = None,
@@ -150,6 +151,7 @@ class VlpnDataModule():
150 self.bucket_max_pixels = bucket_max_pixels 151 self.bucket_max_pixels = bucket_max_pixels
151 self.progressive_buckets = progressive_buckets 152 self.progressive_buckets = progressive_buckets
152 self.dropout = dropout 153 self.dropout = dropout
154 self.shuffle = shuffle
153 self.template_key = template_key 155 self.template_key = template_key
154 self.interpolation = interpolation 156 self.interpolation = interpolation
155 self.valid_set_size = valid_set_size 157 self.valid_set_size = valid_set_size
@@ -240,7 +242,7 @@ class VlpnDataModule():
240 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 242 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
241 batch_size=self.batch_size, generator=generator, 243 batch_size=self.batch_size, generator=generator,
242 size=self.size, interpolation=self.interpolation, 244 size=self.size, interpolation=self.interpolation,
243 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, 245 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
244 ) 246 )
245 247
246 val_dataset = VlpnDataset( 248 val_dataset = VlpnDataset(