diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 15 |
1 files changed, 4 insertions, 11 deletions
diff --git a/data/csv.py b/data/csv.py index 0ad36dc..4da5d64 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -41,6 +41,7 @@ class CSVDataItem(NamedTuple): | |||
| 41 | prompt: list[str] | 41 | prompt: list[str] |
| 42 | cprompt: str | 42 | cprompt: str |
| 43 | nprompt: str | 43 | nprompt: str |
| 44 | mode: list[str] | ||
| 44 | 45 | ||
| 45 | 46 | ||
| 46 | class CSVDataModule(): | 47 | class CSVDataModule(): |
| @@ -56,7 +57,6 @@ class CSVDataModule(): | |||
| 56 | dropout: float = 0, | 57 | dropout: float = 0, |
| 57 | interpolation: str = "bicubic", | 58 | interpolation: str = "bicubic", |
| 58 | center_crop: bool = False, | 59 | center_crop: bool = False, |
| 59 | mode: Optional[str] = None, | ||
| 60 | template_key: str = "template", | 60 | template_key: str = "template", |
| 61 | valid_set_size: Optional[int] = None, | 61 | valid_set_size: Optional[int] = None, |
| 62 | generator: Optional[torch.Generator] = None, | 62 | generator: Optional[torch.Generator] = None, |
| @@ -81,7 +81,6 @@ class CSVDataModule(): | |||
| 81 | self.repeats = repeats | 81 | self.repeats = repeats |
| 82 | self.dropout = dropout | 82 | self.dropout = dropout |
| 83 | self.center_crop = center_crop | 83 | self.center_crop = center_crop |
| 84 | self.mode = mode | ||
| 85 | self.template_key = template_key | 84 | self.template_key = template_key |
| 86 | self.interpolation = interpolation | 85 | self.interpolation = interpolation |
| 87 | self.valid_set_size = valid_set_size | 86 | self.valid_set_size = valid_set_size |
| @@ -113,6 +112,7 @@ class CSVDataModule(): | |||
| 113 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 112 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 114 | expansions | 113 | expansions |
| 115 | )), | 114 | )), |
| 115 | item["mode"].split(", ") if "mode" in item else [] | ||
| 116 | ) | 116 | ) |
| 117 | for item in data | 117 | for item in data |
| 118 | ] | 118 | ] |
| @@ -133,6 +133,7 @@ class CSVDataModule(): | |||
| 133 | item.prompt, | 133 | item.prompt, |
| 134 | item.cprompt, | 134 | item.cprompt, |
| 135 | item.nprompt, | 135 | item.nprompt, |
| 136 | item.mode, | ||
| 136 | ) | 137 | ) |
| 137 | for item in items | 138 | for item in items |
| 138 | for i in range(image_multiplier) | 139 | for i in range(image_multiplier) |
| @@ -145,20 +146,12 @@ class CSVDataModule(): | |||
| 145 | expansions = metadata["expansions"] if "expansions" in metadata else {} | 146 | expansions = metadata["expansions"] if "expansions" in metadata else {} |
| 146 | items = metadata["items"] if "items" in metadata else [] | 147 | items = metadata["items"] if "items" in metadata else [] |
| 147 | 148 | ||
| 148 | if self.mode is not None: | ||
| 149 | items = [ | ||
| 150 | item | ||
| 151 | for item in items | ||
| 152 | if "mode" in item and self.mode in item["mode"].split(", ") | ||
| 153 | ] | ||
| 154 | items = self.prepare_items(template, expansions, items) | 149 | items = self.prepare_items(template, expansions, items) |
| 155 | items = self.filter_items(items) | 150 | items = self.filter_items(items) |
| 156 | 151 | ||
| 157 | num_images = len(items) | 152 | num_images = len(items) |
| 158 | 153 | ||
| 159 | valid_set_size = int(num_images * 0.1) | 154 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1) |
| 160 | if self.valid_set_size: | ||
| 161 | valid_set_size = min(valid_set_size, self.valid_set_size) | ||
| 162 | valid_set_size = max(valid_set_size, 1) | 155 | valid_set_size = max(valid_set_size, 1) |
| 163 | train_set_size = num_images - valid_set_size | 156 | train_set_size = num_images - valid_set_size |
| 164 | 157 | ||
