diff options
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 54 |
1 files changed, 40 insertions, 14 deletions
diff --git a/data/csv.py b/data/csv.py index abd329d..dcaf7d3 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -1,5 +1,3 @@ | |||
| 1 | import math | ||
| 2 | import os | ||
| 3 | import pandas as pd | 1 | import pandas as pd |
| 4 | from pathlib import Path | 2 | from pathlib import Path |
| 5 | import pytorch_lightning as pl | 3 | import pytorch_lightning as pl |
| @@ -16,6 +14,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 16 | instance_identifier, | 14 | instance_identifier, |
| 17 | class_identifier=None, | 15 | class_identifier=None, |
| 18 | class_subdir="db_cls", | 16 | class_subdir="db_cls", |
| 17 | num_class_images=2, | ||
| 19 | size=512, | 18 | size=512, |
| 20 | repeats=100, | 19 | repeats=100, |
| 21 | interpolation="bicubic", | 20 | interpolation="bicubic", |
| @@ -33,6 +32,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 33 | self.data_root = self.data_file.parent | 32 | self.data_root = self.data_file.parent |
| 34 | self.class_root = self.data_root.joinpath(class_subdir) | 33 | self.class_root = self.data_root.joinpath(class_subdir) |
| 35 | self.class_root.mkdir(parents=True, exist_ok=True) | 34 | self.class_root.mkdir(parents=True, exist_ok=True) |
| 35 | self.num_class_images = num_class_images | ||
| 36 | 36 | ||
| 37 | self.tokenizer = tokenizer | 37 | self.tokenizer = tokenizer |
| 38 | self.instance_identifier = instance_identifier | 38 | self.instance_identifier = instance_identifier |
| @@ -48,15 +48,37 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 48 | 48 | ||
| 49 | def prepare_data(self): | 49 | def prepare_data(self): |
| 50 | metadata = pd.read_csv(self.data_file) | 50 | metadata = pd.read_csv(self.data_file) |
| 51 | instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] | 51 | instance_image_paths = [ |
| 52 | class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] | 52 | self.data_root.joinpath(f) |
| 53 | prompts = metadata['prompt'].values | 53 | for f in metadata['image'].values |
| 54 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) | 54 | for i in range(self.num_class_images) |
| 55 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) | 55 | ] |
| 56 | self.data = [(i, c, p, n) | 56 | class_image_paths = [ |
| 57 | for i, c, p, n, s | 57 | self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}") |
| 58 | in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) | 58 | for f in metadata['image'].values |
| 59 | if s != "x"] | 59 | for i in range(self.num_class_images) |
| 60 | ] | ||
| 61 | prompts = [ | ||
| 62 | prompt | ||
| 63 | for prompt in metadata['prompt'].values | ||
| 64 | for i in range(self.num_class_images) | ||
| 65 | ] | ||
| 66 | nprompts = [ | ||
| 67 | nprompt | ||
| 68 | for nprompt in metadata['nprompt'].values | ||
| 69 | for i in range(self.num_class_images) | ||
| 70 | ] if 'nprompt' in metadata else [""] * len(instance_image_paths) | ||
| 71 | skips = [ | ||
| 72 | skip | ||
| 73 | for skip in metadata['skip'].values | ||
| 74 | for i in range(self.num_class_images) | ||
| 75 | ] if 'skip' in metadata else [""] * len(instance_image_paths) | ||
| 76 | self.data = [ | ||
| 77 | (i, c, p, n) | ||
| 78 | for i, c, p, n, s | ||
| 79 | in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) | ||
| 80 | if s != "x" | ||
| 81 | ] | ||
| 60 | 82 | ||
| 61 | def setup(self, stage=None): | 83 | def setup(self, stage=None): |
| 62 | valid_set_size = int(len(self.data) * 0.2) | 84 | valid_set_size = int(len(self.data) * 0.2) |
| @@ -69,6 +91,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 69 | 91 | ||
| 70 | train_dataset = CSVDataset(self.data_train, self.tokenizer, | 92 | train_dataset = CSVDataset(self.data_train, self.tokenizer, |
| 71 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | 93 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
| 94 | num_class_images=self.num_class_images, | ||
| 72 | size=self.size, interpolation=self.interpolation, | 95 | size=self.size, interpolation=self.interpolation, |
| 73 | center_crop=self.center_crop, repeats=self.repeats) | 96 | center_crop=self.center_crop, repeats=self.repeats) |
| 74 | val_dataset = CSVDataset(self.data_val, self.tokenizer, | 97 | val_dataset = CSVDataset(self.data_val, self.tokenizer, |
| @@ -93,6 +116,7 @@ class CSVDataset(Dataset): | |||
| 93 | tokenizer, | 116 | tokenizer, |
| 94 | instance_identifier, | 117 | instance_identifier, |
| 95 | class_identifier=None, | 118 | class_identifier=None, |
| 119 | num_class_images=2, | ||
| 96 | size=512, | 120 | size=512, |
| 97 | repeats=1, | 121 | repeats=1, |
| 98 | interpolation="bicubic", | 122 | interpolation="bicubic", |
| @@ -103,6 +127,7 @@ class CSVDataset(Dataset): | |||
| 103 | self.tokenizer = tokenizer | 127 | self.tokenizer = tokenizer |
| 104 | self.instance_identifier = instance_identifier | 128 | self.instance_identifier = instance_identifier |
| 105 | self.class_identifier = class_identifier | 129 | self.class_identifier = class_identifier |
| 130 | self.num_class_images = num_class_images | ||
| 106 | self.cache = {} | 131 | self.cache = {} |
| 107 | 132 | ||
| 108 | self.num_instance_images = len(self.data) | 133 | self.num_instance_images = len(self.data) |
| @@ -128,9 +153,10 @@ class CSVDataset(Dataset): | |||
| 128 | 153 | ||
| 129 | def get_example(self, i): | 154 | def get_example(self, i): |
| 130 | instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] | 155 | instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
| 156 | cache_key = f"{instance_image_path}_{class_image_path}" | ||
| 131 | 157 | ||
| 132 | if instance_image_path in self.cache: | 158 | if cache_key in self.cache: |
| 133 | return self.cache[instance_image_path] | 159 | return self.cache[cache_key] |
| 134 | 160 | ||
| 135 | example = {} | 161 | example = {} |
| 136 | 162 | ||
| @@ -149,7 +175,7 @@ class CSVDataset(Dataset): | |||
| 149 | max_length=self.tokenizer.model_max_length, | 175 | max_length=self.tokenizer.model_max_length, |
| 150 | ).input_ids | 176 | ).input_ids |
| 151 | 177 | ||
| 152 | if self.class_identifier is not None: | 178 | if self.num_class_images != 0: |
| 153 | class_image = Image.open(class_image_path) | 179 | class_image = Image.open(class_image_path) |
| 154 | if not class_image.mode == "RGB": | 180 | if not class_image.mode == "RGB": |
| 155 | class_image = class_image.convert("RGB") | 181 | class_image = class_image.convert("RGB") |
