summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/data/csv.py b/data/csv.py
index e25dd3f..edce2b1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -49,7 +49,7 @@ class CSVDataModule():
49 data_file: str, 49 data_file: str,
50 prompt_processor: PromptProcessor, 50 prompt_processor: PromptProcessor,
51 class_subdir: str = "cls", 51 class_subdir: str = "cls",
52 num_class_images: int = 100, 52 num_class_images: int = 1,
53 size: int = 512, 53 size: int = 512,
54 repeats: int = 1, 54 repeats: int = 1,
55 dropout: float = 0, 55 dropout: float = 0,
@@ -117,7 +117,7 @@ class CSVDataModule():
117 return [item for item in items if self.filter(item)] 117 return [item for item in items if self.filter(item)]
118 118
119 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: 119 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
120 image_multiplier = max(math.ceil(num_class_images / len(items)), 1) 120 image_multiplier = max(num_class_images, 1)
121 121
122 return [ 122 return [
123 CSVDataItem( 123 CSVDataItem(