diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 54 |
1 files changed, 40 insertions, 14 deletions
diff --git a/data/csv.py b/data/csv.py index abd329d..dcaf7d3 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -1,5 +1,3 @@ | |||
1 | import math | ||
2 | import os | ||
3 | import pandas as pd | 1 | import pandas as pd |
4 | from pathlib import Path | 2 | from pathlib import Path |
5 | import pytorch_lightning as pl | 3 | import pytorch_lightning as pl |
@@ -16,6 +14,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
16 | instance_identifier, | 14 | instance_identifier, |
17 | class_identifier=None, | 15 | class_identifier=None, |
18 | class_subdir="db_cls", | 16 | class_subdir="db_cls", |
17 | num_class_images=2, | ||
19 | size=512, | 18 | size=512, |
20 | repeats=100, | 19 | repeats=100, |
21 | interpolation="bicubic", | 20 | interpolation="bicubic", |
@@ -33,6 +32,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
33 | self.data_root = self.data_file.parent | 32 | self.data_root = self.data_file.parent |
34 | self.class_root = self.data_root.joinpath(class_subdir) | 33 | self.class_root = self.data_root.joinpath(class_subdir) |
35 | self.class_root.mkdir(parents=True, exist_ok=True) | 34 | self.class_root.mkdir(parents=True, exist_ok=True) |
35 | self.num_class_images = num_class_images | ||
36 | 36 | ||
37 | self.tokenizer = tokenizer | 37 | self.tokenizer = tokenizer |
38 | self.instance_identifier = instance_identifier | 38 | self.instance_identifier = instance_identifier |
@@ -48,15 +48,37 @@ class CSVDataModule(pl.LightningDataModule): | |||
48 | 48 | ||
49 | def prepare_data(self): | 49 | def prepare_data(self): |
50 | metadata = pd.read_csv(self.data_file) | 50 | metadata = pd.read_csv(self.data_file) |
51 | instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] | 51 | instance_image_paths = [ |
52 | class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] | 52 | self.data_root.joinpath(f) |
53 | prompts = metadata['prompt'].values | 53 | for f in metadata['image'].values |
54 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) | 54 | for i in range(self.num_class_images) |
55 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) | 55 | ] |
56 | self.data = [(i, c, p, n) | 56 | class_image_paths = [ |
57 | for i, c, p, n, s | 57 | self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}") |
58 | in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) | 58 | for f in metadata['image'].values |
59 | if s != "x"] | 59 | for i in range(self.num_class_images) |
60 | ] | ||
61 | prompts = [ | ||
62 | prompt | ||
63 | for prompt in metadata['prompt'].values | ||
64 | for i in range(self.num_class_images) | ||
65 | ] | ||
66 | nprompts = [ | ||
67 | nprompt | ||
68 | for nprompt in metadata['nprompt'].values | ||
69 | for i in range(self.num_class_images) | ||
70 | ] if 'nprompt' in metadata else [""] * len(instance_image_paths) | ||
71 | skips = [ | ||
72 | skip | ||
73 | for skip in metadata['skip'].values | ||
74 | for i in range(self.num_class_images) | ||
75 | ] if 'skip' in metadata else [""] * len(instance_image_paths) | ||
76 | self.data = [ | ||
77 | (i, c, p, n) | ||
78 | for i, c, p, n, s | ||
79 | in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) | ||
80 | if s != "x" | ||
81 | ] | ||
60 | 82 | ||
61 | def setup(self, stage=None): | 83 | def setup(self, stage=None): |
62 | valid_set_size = int(len(self.data) * 0.2) | 84 | valid_set_size = int(len(self.data) * 0.2) |
@@ -69,6 +91,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
69 | 91 | ||
70 | train_dataset = CSVDataset(self.data_train, self.tokenizer, | 92 | train_dataset = CSVDataset(self.data_train, self.tokenizer, |
71 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | 93 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
94 | num_class_images=self.num_class_images, | ||
72 | size=self.size, interpolation=self.interpolation, | 95 | size=self.size, interpolation=self.interpolation, |
73 | center_crop=self.center_crop, repeats=self.repeats) | 96 | center_crop=self.center_crop, repeats=self.repeats) |
74 | val_dataset = CSVDataset(self.data_val, self.tokenizer, | 97 | val_dataset = CSVDataset(self.data_val, self.tokenizer, |
@@ -93,6 +116,7 @@ class CSVDataset(Dataset): | |||
93 | tokenizer, | 116 | tokenizer, |
94 | instance_identifier, | 117 | instance_identifier, |
95 | class_identifier=None, | 118 | class_identifier=None, |
119 | num_class_images=2, | ||
96 | size=512, | 120 | size=512, |
97 | repeats=1, | 121 | repeats=1, |
98 | interpolation="bicubic", | 122 | interpolation="bicubic", |
@@ -103,6 +127,7 @@ class CSVDataset(Dataset): | |||
103 | self.tokenizer = tokenizer | 127 | self.tokenizer = tokenizer |
104 | self.instance_identifier = instance_identifier | 128 | self.instance_identifier = instance_identifier |
105 | self.class_identifier = class_identifier | 129 | self.class_identifier = class_identifier |
130 | self.num_class_images = num_class_images | ||
106 | self.cache = {} | 131 | self.cache = {} |
107 | 132 | ||
108 | self.num_instance_images = len(self.data) | 133 | self.num_instance_images = len(self.data) |
@@ -128,9 +153,10 @@ class CSVDataset(Dataset): | |||
128 | 153 | ||
129 | def get_example(self, i): | 154 | def get_example(self, i): |
130 | instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] | 155 | instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
156 | cache_key = f"{instance_image_path}_{class_image_path}" | ||
131 | 157 | ||
132 | if instance_image_path in self.cache: | 158 | if cache_key in self.cache: |
133 | return self.cache[instance_image_path] | 159 | return self.cache[cache_key] |
134 | 160 | ||
135 | example = {} | 161 | example = {} |
136 | 162 | ||
@@ -149,7 +175,7 @@ class CSVDataset(Dataset): | |||
149 | max_length=self.tokenizer.model_max_length, | 175 | max_length=self.tokenizer.model_max_length, |
150 | ).input_ids | 176 | ).input_ids |
151 | 177 | ||
152 | if self.class_identifier is not None: | 178 | if self.num_class_images != 0: |
153 | class_image = Image.open(class_image_path) | 179 | class_image = Image.open(class_image_path) |
154 | if not class_image.mode == "RGB": | 180 | if not class_image.mode == "RGB": |
155 | class_image = class_image.convert("RGB") | 181 | class_image = class_image.convert("RGB") |