diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-11 12:59:13 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-11 12:59:13 +0100 |
| commit | 8f2b8e8d309470babd9b853fde8f0a081366deae (patch) | |
| tree | 1374e791705e31fa77fefeb5001aad204cdf3224 /data | |
| parent | Support attention_mask of text encoder (diff) | |
| download | textual-inversion-diff-8f2b8e8d309470babd9b853fde8f0a081366deae.tar.gz textual-inversion-diff-8f2b8e8d309470babd9b853fde8f0a081366deae.tar.bz2 textual-inversion-diff-8f2b8e8d309470babd9b853fde8f0a081366deae.zip | |
Training improvements such as tag dropout
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 13 |
1 file changed, 10 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index 23b5299..9125212 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -16,14 +16,17 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): | |||
| 16 | return {"content": prompt} if isinstance(prompt, str) else prompt | 16 | return {"content": prompt} if isinstance(prompt, str) else prompt |
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | def shuffle_prompt(prompt: str): | 19 | def shuffle_prompt(prompt: str, dropout: float = 0): |
| 20 | def handle_block(block: str): | 20 | def handle_block(block: str): |
| 21 | words = block.split(", ") | 21 | words = block.split(", ") |
| 22 | words = [w for w in words if w != ""] | ||
| 23 | if dropout != 0: | ||
| 24 | words = [w for w in words if np.random.random() > dropout] | ||
| 22 | np.random.shuffle(words) | 25 | np.random.shuffle(words) |
| 23 | return ", ".join(words) | 26 | return ", ".join(words) |
| 24 | 27 | ||
| 25 | prompt = prompt.split(". ") | 28 | prompt = prompt.split(". ") |
| 26 | prompt = [handle_block(b) for b in prompt] | 29 | prompt = [handle_block(b) for b in prompt if b != ""] |
| 27 | np.random.shuffle(prompt) | 30 | np.random.shuffle(prompt) |
| 28 | prompt = ". ".join(prompt) | 31 | prompt = ". ".join(prompt) |
| 29 | return prompt | 32 | return prompt |
| @@ -48,6 +51,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 48 | num_class_images: int = 100, | 51 | num_class_images: int = 100, |
| 49 | size: int = 512, | 52 | size: int = 512, |
| 50 | repeats: int = 1, | 53 | repeats: int = 1, |
| 54 | dropout: float = 0, | ||
| 51 | interpolation: str = "bicubic", | 55 | interpolation: str = "bicubic", |
| 52 | center_crop: bool = False, | 56 | center_crop: bool = False, |
| 53 | valid_set_size: Optional[int] = None, | 57 | valid_set_size: Optional[int] = None, |
| @@ -72,6 +76,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 72 | self.class_identifier = class_identifier | 76 | self.class_identifier = class_identifier |
| 73 | self.size = size | 77 | self.size = size |
| 74 | self.repeats = repeats | 78 | self.repeats = repeats |
| 79 | self.dropout = dropout | ||
| 75 | self.center_crop = center_crop | 80 | self.center_crop = center_crop |
| 76 | self.interpolation = interpolation | 81 | self.interpolation = interpolation |
| 77 | self.valid_set_size = valid_set_size | 82 | self.valid_set_size = valid_set_size |
| @@ -123,7 +128,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 123 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | 128 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
| 124 | num_class_images=self.num_class_images, | 129 | num_class_images=self.num_class_images, |
| 125 | size=self.size, interpolation=self.interpolation, | 130 | size=self.size, interpolation=self.interpolation, |
| 126 | center_crop=self.center_crop, repeats=self.repeats) | 131 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) |
| 127 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 132 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, |
| 128 | instance_identifier=self.instance_identifier, | 133 | instance_identifier=self.instance_identifier, |
| 129 | size=self.size, interpolation=self.interpolation, | 134 | size=self.size, interpolation=self.interpolation, |
| @@ -153,6 +158,7 @@ class CSVDataset(Dataset): | |||
| 153 | num_class_images: int = 0, | 158 | num_class_images: int = 0, |
| 154 | size: int = 512, | 159 | size: int = 512, |
| 155 | repeats: int = 1, | 160 | repeats: int = 1, |
| 161 | dropout: float = 0, | ||
| 156 | interpolation: str = "bicubic", | 162 | interpolation: str = "bicubic", |
| 157 | center_crop: bool = False, | 163 | center_crop: bool = False, |
| 158 | ): | 164 | ): |
| @@ -163,6 +169,7 @@ class CSVDataset(Dataset): | |||
| 163 | self.instance_identifier = instance_identifier | 169 | self.instance_identifier = instance_identifier |
| 164 | self.class_identifier = class_identifier | 170 | self.class_identifier = class_identifier |
| 165 | self.num_class_images = num_class_images | 171 | self.num_class_images = num_class_images |
| 172 | self.dropout = dropout | ||
| 166 | self.image_cache = {} | 173 | self.image_cache = {} |
| 167 | 174 | ||
| 168 | self.num_instance_images = len(self.data) | 175 | self.num_instance_images = len(self.data) |
