summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py38
1 files changed, 31 insertions, 7 deletions
diff --git a/data/csv.py b/data/csv.py
index 9125212..9c3c3f8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -54,8 +54,10 @@ class CSVDataModule(pl.LightningDataModule):
54 dropout: float = 0, 54 dropout: float = 0,
55 interpolation: str = "bicubic", 55 interpolation: str = "bicubic",
56 center_crop: bool = False, 56 center_crop: bool = False,
57 template_key: str = "template",
57 valid_set_size: Optional[int] = None, 58 valid_set_size: Optional[int] = None,
58 generator: Optional[torch.Generator] = None, 59 generator: Optional[torch.Generator] = None,
60 keyword_filter: list[str] = [],
59 collate_fn=None, 61 collate_fn=None,
60 num_workers: int = 0 62 num_workers: int = 0
61 ): 63 ):
@@ -78,38 +80,60 @@ class CSVDataModule(pl.LightningDataModule):
78 self.repeats = repeats 80 self.repeats = repeats
79 self.dropout = dropout 81 self.dropout = dropout
80 self.center_crop = center_crop 82 self.center_crop = center_crop
83 self.template_key = template_key
81 self.interpolation = interpolation 84 self.interpolation = interpolation
82 self.valid_set_size = valid_set_size 85 self.valid_set_size = valid_set_size
83 self.generator = generator 86 self.generator = generator
87 self.keyword_filter = keyword_filter
84 self.collate_fn = collate_fn 88 self.collate_fn = collate_fn
85 self.num_workers = num_workers 89 self.num_workers = num_workers
86 self.batch_size = batch_size 90 self.batch_size = batch_size
87 91
88 def prepare_subdata(self, template, data, num_class_images=1): 92 def prepare_items(self, template, data) -> list[CSVDataItem]:
89 image = template["image"] if "image" in template else "{}" 93 image = template["image"] if "image" in template else "{}"
90 prompt = template["prompt"] if "prompt" in template else "{content}" 94 prompt = template["prompt"] if "prompt" in template else "{content}"
91 nprompt = template["nprompt"] if "nprompt" in template else "{content}" 95 nprompt = template["nprompt"] if "nprompt" in template else "{content}"
92 96
93 image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
94
95 return [ 97 return [
96 CSVDataItem( 98 CSVDataItem(
97 self.data_root.joinpath(image.format(item["image"])), 99 self.data_root.joinpath(image.format(item["image"])),
98 self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"), 100 None,
99 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 101 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
100 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) 102 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
101 ) 103 )
102 for item in data 104 for item in data
105 ]
106
107 def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
108 if len(self.keyword_filter) == 0:
109 return items
110
111 return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)]
112
113 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
114 image_multiplier = max(math.ceil(num_class_images / len(items)), 1)
115
116 return [
117 CSVDataItem(
118 item.instance_image_path,
119 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
120 item.prompt,
121 item.nprompt
122 )
123 for item in items
103 for i in range(image_multiplier) 124 for i in range(image_multiplier)
104 ] 125 ]
105 126
106 def prepare_data(self): 127 def prepare_data(self):
107 with open(self.data_file, 'rt') as f: 128 with open(self.data_file, 'rt') as f:
108 metadata = json.load(f) 129 metadata = json.load(f)
109 template = metadata["template"] if "template" in metadata else {} 130 template = metadata[self.template_key] if self.template_key in metadata else {}
110 items = metadata["items"] if "items" in metadata else [] 131 items = metadata["items"] if "items" in metadata else []
111 132
112 items = [item for item in items if not "skip" in item or item["skip"] != True] 133 items = [item for item in items if not "skip" in item or item["skip"] != True]
134 items = self.prepare_items(template, items)
135 items = self.filter_items(items)
136
113 num_images = len(items) 137 num_images = len(items)
114 138
115 valid_set_size = int(num_images * 0.1) 139 valid_set_size = int(num_images * 0.1)
@@ -120,8 +144,8 @@ class CSVDataModule(pl.LightningDataModule):
120 144
121 data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) 145 data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
122 146
123 self.data_train = self.prepare_subdata(template, data_train, self.num_class_images) 147 self.data_train = self.pad_items(data_train, self.num_class_images)
124 self.data_val = self.prepare_subdata(template, data_val) 148 self.data_val = self.pad_items(data_val)
125 149
126 def setup(self, stage=None): 150 def setup(self, stage=None):
127 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, 151 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,