summaryrefslogtreecommitdiffstats
path: root/data/csv.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/csv.py')
-rw-r--r--data/csv.py83
1 files changed, 37 insertions, 46 deletions
diff --git a/data/csv.py b/data/csv.py
index 316c099..4c91ded 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,11 +1,14 @@
1import math 1import math
2import pandas as pd 2import pandas as pd
3import torch
3from pathlib import Path 4from pathlib import Path
4import pytorch_lightning as pl 5import pytorch_lightning as pl
5from PIL import Image 6from PIL import Image
6from torch.utils.data import Dataset, DataLoader, random_split 7from torch.utils.data import Dataset, DataLoader, random_split
7from torchvision import transforms 8from torchvision import transforms
8from typing import NamedTuple, List 9from typing import NamedTuple, List, Optional
10
11from models.clip.prompt import PromptProcessor
9 12
10 13
11class CSVDataItem(NamedTuple): 14class CSVDataItem(NamedTuple):
@@ -18,19 +21,19 @@ class CSVDataItem(NamedTuple):
18class CSVDataModule(pl.LightningDataModule): 21class CSVDataModule(pl.LightningDataModule):
19 def __init__( 22 def __init__(
20 self, 23 self,
21 batch_size, 24 batch_size: int,
22 data_file, 25 data_file: str,
23 tokenizer, 26 prompt_processor: PromptProcessor,
24 instance_identifier, 27 instance_identifier: str,
25 class_identifier=None, 28 class_identifier: Optional[str] = None,
26 class_subdir="cls", 29 class_subdir: str = "cls",
27 num_class_images=100, 30 num_class_images: int = 100,
28 size=512, 31 size: int = 512,
29 repeats=100, 32 repeats: int = 1,
30 interpolation="bicubic", 33 interpolation: str = "bicubic",
31 center_crop=False, 34 center_crop: bool = False,
32 valid_set_size=None, 35 valid_set_size: Optional[int] = None,
33 generator=None, 36 generator: Optional[torch.Generator] = None,
34 collate_fn=None 37 collate_fn=None
35 ): 38 ):
36 super().__init__() 39 super().__init__()
@@ -45,7 +48,7 @@ class CSVDataModule(pl.LightningDataModule):
45 self.class_root.mkdir(parents=True, exist_ok=True) 48 self.class_root.mkdir(parents=True, exist_ok=True)
46 self.num_class_images = num_class_images 49 self.num_class_images = num_class_images
47 50
48 self.tokenizer = tokenizer 51 self.prompt_processor = prompt_processor
49 self.instance_identifier = instance_identifier 52 self.instance_identifier = instance_identifier
50 self.class_identifier = class_identifier 53 self.class_identifier = class_identifier
51 self.size = size 54 self.size = size
@@ -65,7 +68,7 @@ class CSVDataModule(pl.LightningDataModule):
65 self.data_root.joinpath(item.image), 68 self.data_root.joinpath(item.image),
66 self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), 69 self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
67 item.prompt, 70 item.prompt,
68 item.nprompt if "nprompt" in item else "" 71 item.nprompt
69 ) 72 )
70 for item in data 73 for item in data
71 for i in range(image_multiplier) 74 for i in range(image_multiplier)
@@ -88,12 +91,12 @@ class CSVDataModule(pl.LightningDataModule):
88 self.data_val = self.prepare_subdata(data_val) 91 self.data_val = self.prepare_subdata(data_val)
89 92
90 def setup(self, stage=None): 93 def setup(self, stage=None):
91 train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, 94 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
92 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, 95 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
93 num_class_images=self.num_class_images, 96 num_class_images=self.num_class_images,
94 size=self.size, interpolation=self.interpolation, 97 size=self.size, interpolation=self.interpolation,
95 center_crop=self.center_crop, repeats=self.repeats) 98 center_crop=self.center_crop, repeats=self.repeats)
96 val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, 99 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
97 instance_identifier=self.instance_identifier, 100 instance_identifier=self.instance_identifier,
98 size=self.size, interpolation=self.interpolation, 101 size=self.size, interpolation=self.interpolation,
99 center_crop=self.center_crop, repeats=self.repeats) 102 center_crop=self.center_crop, repeats=self.repeats)
@@ -113,19 +116,19 @@ class CSVDataset(Dataset):
113 def __init__( 116 def __init__(
114 self, 117 self,
115 data: List[CSVDataItem], 118 data: List[CSVDataItem],
116 tokenizer, 119 prompt_processor: PromptProcessor,
117 instance_identifier, 120 instance_identifier: str,
118 batch_size=1, 121 batch_size: int = 1,
119 class_identifier=None, 122 class_identifier: Optional[str] = None,
120 num_class_images=0, 123 num_class_images: int = 0,
121 size=512, 124 size: int = 512,
122 repeats=1, 125 repeats: int = 1,
123 interpolation="bicubic", 126 interpolation: str = "bicubic",
124 center_crop=False, 127 center_crop: bool = False,
125 ): 128 ):
126 129
127 self.data = data 130 self.data = data
128 self.tokenizer = tokenizer 131 self.prompt_processor = prompt_processor
129 self.batch_size = batch_size 132 self.batch_size = batch_size
130 self.instance_identifier = instance_identifier 133 self.instance_identifier = instance_identifier
131 self.class_identifier = class_identifier 134 self.class_identifier = class_identifier
@@ -163,12 +166,6 @@ class CSVDataset(Dataset):
163 166
164 example = {} 167 example = {}
165 168
166 if isinstance(item.prompt, str):
167 item.prompt = [item.prompt]
168
169 if isinstance(item.nprompt, str):
170 item.nprompt = [item.nprompt]
171
172 example["prompts"] = item.prompt 169 example["prompts"] = item.prompt
173 example["nprompts"] = item.nprompt 170 example["nprompts"] = item.nprompt
174 171
@@ -181,12 +178,9 @@ class CSVDataset(Dataset):
181 self.image_cache[item.instance_image_path] = instance_image 178 self.image_cache[item.instance_image_path] = instance_image
182 179
183 example["instance_images"] = instance_image 180 example["instance_images"] = instance_image
184 example["instance_prompt_ids"] = self.tokenizer( 181 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
185 item.prompt.format(self.instance_identifier), 182 item.prompt.format(self.instance_identifier)
186 padding="max_length", 183 )
187 truncation=True,
188 max_length=self.tokenizer.model_max_length,
189 ).input_ids
190 184
191 if self.num_class_images != 0: 185 if self.num_class_images != 0:
192 class_image = Image.open(item.class_image_path) 186 class_image = Image.open(item.class_image_path)
@@ -194,12 +188,9 @@ class CSVDataset(Dataset):
194 class_image = class_image.convert("RGB") 188 class_image = class_image.convert("RGB")
195 189
196 example["class_images"] = class_image 190 example["class_images"] = class_image
197 example["class_prompt_ids"] = self.tokenizer( 191 example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
198 item.prompt.format(self.class_identifier), 192 item.nprompt.format(self.class_identifier)
199 padding="max_length", 193 )
200 truncation=True,
201 max_length=self.tokenizer.model_max_length,
202 ).input_ids
203 194
204 self.cache[item.instance_image_path] = example 195 self.cache[item.instance_image_path] = example
205 return example 196 return example