diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-03 11:44:42 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-03 11:44:42 +0200 |
| commit | d0ce16b542deac464e097c38adc5095802bd6763 (patch) | |
| tree | 8dc3295ac2a03c46eb1add566691134335ee6657 /data/dreambooth | |
| parent | Added script to convert Diffusers -> SD (diff) | |
| download | textual-inversion-diff-d0ce16b542deac464e097c38adc5095802bd6763.tar.gz textual-inversion-diff-d0ce16b542deac464e097c38adc5095802bd6763.tar.bz2 textual-inversion-diff-d0ce16b542deac464e097c38adc5095802bd6763.zip | |
Assign unused images in validation dataset to train dataset
Diffstat (limited to 'data/dreambooth')
| -rw-r--r-- | data/dreambooth/csv.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 1676d35..4087226 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -21,6 +21,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 21 | interpolation="bicubic", | 21 | interpolation="bicubic", |
| 22 | identifier="*", | 22 | identifier="*", |
| 23 | center_crop=False, | 23 | center_crop=False, |
| 24 | valid_set_size=None, | ||
| 24 | collate_fn=None): | 25 | collate_fn=None): |
| 25 | super().__init__() | 26 | super().__init__() |
| 26 | 27 | ||
| @@ -39,6 +40,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 39 | self.identifier = identifier | 40 | self.identifier = identifier |
| 40 | self.center_crop = center_crop | 41 | self.center_crop = center_crop |
| 41 | self.interpolation = interpolation | 42 | self.interpolation = interpolation |
| 43 | self.valid_set_size = valid_set_size | ||
| 42 | self.collate_fn = collate_fn | 44 | self.collate_fn = collate_fn |
| 43 | self.batch_size = batch_size | 45 | self.batch_size = batch_size |
| 44 | 46 | ||
| @@ -50,8 +52,11 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 50 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 52 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] |
| 51 | 53 | ||
| 52 | def setup(self, stage=None): | 54 | def setup(self, stage=None): |
| 53 | train_set_size = int(len(self.data_full) * 0.8) | 55 | valid_set_size = int(len(self.data_full) * 0.2) |
| 54 | valid_set_size = len(self.data_full) - train_set_size | 56 | if self.valid_set_size: |
| 57 | valid_set_size = math.min(valid_set_size, self.valid_set_size) | ||
| 58 | train_set_size = len(self.data_full) - valid_set_size | ||
| 59 | |||
| 55 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 60 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) |
| 56 | 61 | ||
| 57 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 62 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, |
