diff options
author | Volpeon <git@volpeon.ink> | 2022-10-03 12:08:16 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-03 12:08:16 +0200 |
commit | 6c072fe50b3bfc561f22e5d591212d30de3c2dd2 (patch) | |
tree | e6dd60b5fa696d614ccc1cddb869c12c29f6ab46 /data | |
parent | Assign unused images in validation dataset to train dataset (diff) | |
download | textual-inversion-diff-6c072fe50b3bfc561f22e5d591212d30de3c2dd2.tar.gz textual-inversion-diff-6c072fe50b3bfc561f22e5d591212d30de3c2dd2.tar.bz2 textual-inversion-diff-6c072fe50b3bfc561f22e5d591212d30de3c2dd2.zip |
Fixed euler_a generator argument
Diffstat (limited to 'data')
-rw-r--r-- | data/dreambooth/csv.py | 6 | ||||
-rw-r--r-- | data/textual_inversion/csv.py | 8 |
2 files changed, 9 insertions, 5 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4087226..08ed49c 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -22,6 +22,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
22 | identifier="*", | 22 | identifier="*", |
23 | center_crop=False, | 23 | center_crop=False, |
24 | valid_set_size=None, | 24 | valid_set_size=None, |
25 | generator=None, | ||
25 | collate_fn=None): | 26 | collate_fn=None): |
26 | super().__init__() | 27 | super().__init__() |
27 | 28 | ||
@@ -41,6 +42,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
41 | self.center_crop = center_crop | 42 | self.center_crop = center_crop |
42 | self.interpolation = interpolation | 43 | self.interpolation = interpolation |
43 | self.valid_set_size = valid_set_size | 44 | self.valid_set_size = valid_set_size |
45 | self.generator = generator | ||
44 | self.collate_fn = collate_fn | 46 | self.collate_fn = collate_fn |
45 | self.batch_size = batch_size | 47 | self.batch_size = batch_size |
46 | 48 | ||
@@ -54,10 +56,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
54 | def setup(self, stage=None): | 56 | def setup(self, stage=None): |
55 | valid_set_size = int(len(self.data_full) * 0.2) | 57 | valid_set_size = int(len(self.data_full) * 0.2) |
56 | if self.valid_set_size: | 58 | if self.valid_set_size: |
57 | valid_set_size = math.min(valid_set_size, self.valid_set_size) | 59 | valid_set_size = min(valid_set_size, self.valid_set_size) |
58 | train_set_size = len(self.data_full) - valid_set_size | 60 | train_set_size = len(self.data_full) - valid_set_size |
59 | 61 | ||
60 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 62 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) |
61 | 63 | ||
62 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 64 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, |
63 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, | 65 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, |
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index e082511..3ac57df 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
@@ -19,7 +19,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
19 | interpolation="bicubic", | 19 | interpolation="bicubic", |
20 | placeholder_token="*", | 20 | placeholder_token="*", |
21 | center_crop=False, | 21 | center_crop=False, |
22 | valid_set_size=None): | 22 | valid_set_size=None, |
23 | generator=None): | ||
23 | super().__init__() | 24 | super().__init__() |
24 | 25 | ||
25 | self.data_file = Path(data_file) | 26 | self.data_file = Path(data_file) |
@@ -35,6 +36,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
35 | self.center_crop = center_crop | 36 | self.center_crop = center_crop |
36 | self.interpolation = interpolation | 37 | self.interpolation = interpolation |
37 | self.valid_set_size = valid_set_size | 38 | self.valid_set_size = valid_set_size |
39 | self.generator = generator | ||
38 | 40 | ||
39 | self.batch_size = batch_size | 41 | self.batch_size = batch_size |
40 | 42 | ||
@@ -48,10 +50,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
48 | def setup(self, stage=None): | 50 | def setup(self, stage=None): |
49 | valid_set_size = int(len(self.data_full) * 0.2) | 51 | valid_set_size = int(len(self.data_full) * 0.2) |
50 | if self.valid_set_size: | 52 | if self.valid_set_size: |
51 | valid_set_size = math.min(valid_set_size, self.valid_set_size) | 53 | valid_set_size = min(valid_set_size, self.valid_set_size) |
52 | train_set_size = len(self.data_full) - valid_set_size | 54 | train_set_size = len(self.data_full) - valid_set_size |
53 | 55 | ||
54 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 56 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) |
55 | 57 | ||
56 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, | 58 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, |
57 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 59 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |