From b491a817088790219e052b86173e128c55b597f8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Dec 2022 21:53:46 +0100 Subject: num_class_images is now class images per train image --- data/csv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index e25dd3f..edce2b1 100644 --- a/data/csv.py +++ b/data/csv.py @@ -49,7 +49,7 @@ class CSVDataModule(): data_file: str, prompt_processor: PromptProcessor, class_subdir: str = "cls", - num_class_images: int = 100, + num_class_images: int = 1, size: int = 512, repeats: int = 1, dropout: float = 0, @@ -117,7 +117,7 @@ class CSVDataModule(): return [item for item in items if self.filter(item)] def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: - image_multiplier = max(math.ceil(num_class_images / len(items)), 1) + image_multiplier = max(num_class_images, 1) return [ CSVDataItem( -- cgit v1.2.3-70-g09d2