diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 40 |
1 files changed, 20 insertions, 20 deletions
diff --git a/data/csv.py b/data/csv.py index 4da5d64..803271b 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -41,28 +41,28 @@ class CSVDataItem(NamedTuple): | |||
41 | prompt: list[str] | 41 | prompt: list[str] |
42 | cprompt: str | 42 | cprompt: str |
43 | nprompt: str | 43 | nprompt: str |
44 | mode: list[str] | 44 | collection: list[str] |
45 | 45 | ||
46 | 46 | ||
47 | class CSVDataModule(): | 47 | class CSVDataModule(): |
48 | def __init__( | 48 | def __init__( |
49 | self, | 49 | self, |
50 | batch_size: int, | 50 | batch_size: int, |
51 | data_file: str, | 51 | data_file: str, |
52 | prompt_processor: PromptProcessor, | 52 | prompt_processor: PromptProcessor, |
53 | class_subdir: str = "cls", | 53 | class_subdir: str = "cls", |
54 | num_class_images: int = 1, | 54 | num_class_images: int = 1, |
55 | size: int = 768, | 55 | size: int = 768, |
56 | repeats: int = 1, | 56 | repeats: int = 1, |
57 | dropout: float = 0, | 57 | dropout: float = 0, |
58 | interpolation: str = "bicubic", | 58 | interpolation: str = "bicubic", |
59 | center_crop: bool = False, | 59 | center_crop: bool = False, |
60 | template_key: str = "template", | 60 | template_key: str = "template", |
61 | valid_set_size: Optional[int] = None, | 61 | valid_set_size: Optional[int] = None, |
62 | generator: Optional[torch.Generator] = None, | 62 | generator: Optional[torch.Generator] = None, |
63 | filter: Optional[Callable[[CSVDataItem], bool]] = None, | 63 | filter: Optional[Callable[[CSVDataItem], bool]] = None, |
64 | collate_fn=None, | 64 | collate_fn=None, |
65 | num_workers: int = 0 | 65 | num_workers: int = 0 |
66 | ): | 66 | ): |
67 | super().__init__() | 67 | super().__init__() |
68 | 68 | ||
@@ -112,7 +112,7 @@ class CSVDataModule(): | |||
112 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 112 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
113 | expansions | 113 | expansions |
114 | )), | 114 | )), |
115 | item["mode"].split(", ") if "mode" in item else [] | 115 | item["collection"].split(", ") if "collection" in item else [] |
116 | ) | 116 | ) |
117 | for item in data | 117 | for item in data |
118 | ] | 118 | ] |
@@ -133,7 +133,7 @@ class CSVDataModule(): | |||
133 | item.prompt, | 133 | item.prompt, |
134 | item.cprompt, | 134 | item.cprompt, |
135 | item.nprompt, | 135 | item.nprompt, |
136 | item.mode, | 136 | item.collection, |
137 | ) | 137 | ) |
138 | for item in items | 138 | for item in items |
139 | for i in range(image_multiplier) | 139 | for i in range(image_multiplier) |