diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 13 |
1 file changed, 10 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index 23b5299..9125212 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -16,14 +16,17 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): | |||
16 | return {"content": prompt} if isinstance(prompt, str) else prompt | 16 | return {"content": prompt} if isinstance(prompt, str) else prompt |
17 | 17 | ||
18 | 18 | ||
19 | def shuffle_prompt(prompt: str): | 19 | def shuffle_prompt(prompt: str, dropout: float = 0): |
20 | def handle_block(block: str): | 20 | def handle_block(block: str): |
21 | words = block.split(", ") | 21 | words = block.split(", ") |
22 | words = [w for w in words if w != ""] | ||
23 | if dropout != 0: | ||
24 | words = [w for w in words if np.random.random() > dropout] | ||
22 | np.random.shuffle(words) | 25 | np.random.shuffle(words) |
23 | return ", ".join(words) | 26 | return ", ".join(words) |
24 | 27 | ||
25 | prompt = prompt.split(". ") | 28 | prompt = prompt.split(". ") |
26 | prompt = [handle_block(b) for b in prompt] | 29 | prompt = [handle_block(b) for b in prompt if b != ""] |
27 | np.random.shuffle(prompt) | 30 | np.random.shuffle(prompt) |
28 | prompt = ". ".join(prompt) | 31 | prompt = ". ".join(prompt) |
29 | return prompt | 32 | return prompt |
@@ -48,6 +51,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
48 | num_class_images: int = 100, | 51 | num_class_images: int = 100, |
49 | size: int = 512, | 52 | size: int = 512, |
50 | repeats: int = 1, | 53 | repeats: int = 1, |
54 | dropout: float = 0, | ||
51 | interpolation: str = "bicubic", | 55 | interpolation: str = "bicubic", |
52 | center_crop: bool = False, | 56 | center_crop: bool = False, |
53 | valid_set_size: Optional[int] = None, | 57 | valid_set_size: Optional[int] = None, |
@@ -72,6 +76,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
72 | self.class_identifier = class_identifier | 76 | self.class_identifier = class_identifier |
73 | self.size = size | 77 | self.size = size |
74 | self.repeats = repeats | 78 | self.repeats = repeats |
79 | self.dropout = dropout | ||
75 | self.center_crop = center_crop | 80 | self.center_crop = center_crop |
76 | self.interpolation = interpolation | 81 | self.interpolation = interpolation |
77 | self.valid_set_size = valid_set_size | 82 | self.valid_set_size = valid_set_size |
@@ -123,7 +128,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
123 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | 128 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
124 | num_class_images=self.num_class_images, | 129 | num_class_images=self.num_class_images, |
125 | size=self.size, interpolation=self.interpolation, | 130 | size=self.size, interpolation=self.interpolation, |
126 | center_crop=self.center_crop, repeats=self.repeats) | 131 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) |
127 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 132 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
128 | instance_identifier=self.instance_identifier, | 133 | instance_identifier=self.instance_identifier, |
129 | size=self.size, interpolation=self.interpolation, | 134 | size=self.size, interpolation=self.interpolation, |
@@ -153,6 +158,7 @@ class CSVDataset(Dataset): | |||
153 | num_class_images: int = 0, | 158 | num_class_images: int = 0, |
154 | size: int = 512, | 159 | size: int = 512, |
155 | repeats: int = 1, | 160 | repeats: int = 1, |
161 | dropout: float = 0, | ||
156 | interpolation: str = "bicubic", | 162 | interpolation: str = "bicubic", |
157 | center_crop: bool = False, | 163 | center_crop: bool = False, |
158 | ): | 164 | ): |
@@ -163,6 +169,7 @@ class CSVDataset(Dataset): | |||
163 | self.instance_identifier = instance_identifier | 169 | self.instance_identifier = instance_identifier |
164 | self.class_identifier = class_identifier | 170 | self.class_identifier = class_identifier |
165 | self.num_class_images = num_class_images | 171 | self.num_class_images = num_class_images |
172 | self.dropout = dropout | ||
166 | self.image_cache = {} | 173 | self.image_cache = {} |
167 | 174 | ||
168 | self.num_instance_images = len(self.data) | 175 | self.num_instance_images = len(self.data) |