summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-23 21:53:46 +0100
committerVolpeon <git@volpeon.ink>2022-12-23 21:53:46 +0100
commitb491a817088790219e052b86173e128c55b597f8 (patch)
tree49ea8480da78b002f36b15b52d1b9bc501837819 /data
parentImproved class prompt handling (diff)
downloadtextual-inversion-diff-b491a817088790219e052b86173e128c55b597f8.tar.gz
textual-inversion-diff-b491a817088790219e052b86173e128c55b597f8.tar.bz2
textual-inversion-diff-b491a817088790219e052b86173e128c55b597f8.zip
num_class_images is now class images per train image
Diffstat (limited to 'data')
-rw-r--r--data/csv.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/data/csv.py b/data/csv.py
index e25dd3f..edce2b1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -49,7 +49,7 @@ class CSVDataModule():
49 data_file: str, 49 data_file: str,
50 prompt_processor: PromptProcessor, 50 prompt_processor: PromptProcessor,
51 class_subdir: str = "cls", 51 class_subdir: str = "cls",
52 num_class_images: int = 100, 52 num_class_images: int = 1,
53 size: int = 512, 53 size: int = 512,
54 repeats: int = 1, 54 repeats: int = 1,
55 dropout: float = 0, 55 dropout: float = 0,
@@ -117,7 +117,7 @@ class CSVDataModule():
117 return [item for item in items if self.filter(item)] 117 return [item for item in items if self.filter(item)]
118 118
119 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: 119 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
120 image_multiplier = max(math.ceil(num_class_images / len(items)), 1) 120 image_multiplier = max(num_class_images, 1)
121 121
122 return [ 122 return [
123 CSVDataItem( 123 CSVDataItem(