diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/data/csv.py b/data/csv.py index e25dd3f..edce2b1 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -49,7 +49,7 @@ class CSVDataModule(): | |||
| 49 | data_file: str, | 49 | data_file: str, |
| 50 | prompt_processor: PromptProcessor, | 50 | prompt_processor: PromptProcessor, |
| 51 | class_subdir: str = "cls", | 51 | class_subdir: str = "cls", |
| 52 | num_class_images: int = 100, | 52 | num_class_images: int = 1, |
| 53 | size: int = 512, | 53 | size: int = 512, |
| 54 | repeats: int = 1, | 54 | repeats: int = 1, |
| 55 | dropout: float = 0, | 55 | dropout: float = 0, |
| @@ -117,7 +117,7 @@ class CSVDataModule(): | |||
| 117 | return [item for item in items if self.filter(item)] | 117 | return [item for item in items if self.filter(item)] |
| 118 | 118 | ||
| 119 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: | 119 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: |
| 120 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) | 120 | image_multiplier = max(num_class_images, 1) |
| 121 | 121 | ||
| 122 | return [ | 122 | return [ |
| 123 | CSVDataItem( | 123 | CSVDataItem( |
