diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index 20ac992..053457b 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -54,6 +54,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
54 | dropout: float = 0, | 54 | dropout: float = 0, |
55 | interpolation: str = "bicubic", | 55 | interpolation: str = "bicubic", |
56 | center_crop: bool = False, | 56 | center_crop: bool = False, |
57 | mode: Optional[str] = None, | ||
57 | template_key: str = "template", | 58 | template_key: str = "template", |
58 | valid_set_size: Optional[int] = None, | 59 | valid_set_size: Optional[int] = None, |
59 | generator: Optional[torch.Generator] = None, | 60 | generator: Optional[torch.Generator] = None, |
@@ -80,6 +81,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
80 | self.repeats = repeats | 81 | self.repeats = repeats |
81 | self.dropout = dropout | 82 | self.dropout = dropout |
82 | self.center_crop = center_crop | 83 | self.center_crop = center_crop |
84 | self.mode = mode | ||
83 | self.template_key = template_key | 85 | self.template_key = template_key |
84 | self.interpolation = interpolation | 86 | self.interpolation = interpolation |
85 | self.valid_set_size = valid_set_size | 87 | self.valid_set_size = valid_set_size |
@@ -99,7 +101,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
99 | self.data_root.joinpath(image.format(item["image"])), | 101 | self.data_root.joinpath(image.format(item["image"])), |
100 | None, | 102 | None, |
101 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 103 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
102 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) | 104 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
103 | ) | 105 | ) |
104 | for item in data | 106 | for item in data |
105 | ] | 107 | ] |
@@ -118,7 +120,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
118 | item.instance_image_path, | 120 | item.instance_image_path, |
119 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | 121 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), |
120 | item.prompt, | 122 | item.prompt, |
121 | item.nprompt | 123 | item.nprompt, |
122 | ) | 124 | ) |
123 | for item in items | 125 | for item in items |
124 | for i in range(image_multiplier) | 126 | for i in range(image_multiplier) |
@@ -130,7 +132,12 @@ class CSVDataModule(pl.LightningDataModule): | |||
130 | template = metadata[self.template_key] if self.template_key in metadata else {} | 132 | template = metadata[self.template_key] if self.template_key in metadata else {} |
131 | items = metadata["items"] if "items" in metadata else [] | 133 | items = metadata["items"] if "items" in metadata else [] |
132 | 134 | ||
133 | items = [item for item in items if not "skip" in item or item["skip"] != True] | 135 | if self.mode is not None: |
136 | items = [ | ||
137 | item | ||
138 | for item in items | ||
139 | if "mode" in item and self.mode in item["mode"] | ||
140 | ] | ||
134 | items = self.prepare_items(template, items) | 141 | items = self.prepare_items(template, items) |
135 | items = self.filter_items(items) | 142 | items = self.filter_items(items) |
136 | 143 | ||