diff options
-rw-r--r-- | data/dreambooth/csv.py | 9 | ||||
-rw-r--r-- | data/textual_inversion/csv.py | 11 | ||||
-rw-r--r-- | dreambooth.py | 3 | ||||
-rw-r--r-- | textual_inversion.py | 12 |
4 files changed, 26 insertions, 9 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 1676d35..4087226 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -21,6 +21,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
21 | interpolation="bicubic", | 21 | interpolation="bicubic", |
22 | identifier="*", | 22 | identifier="*", |
23 | center_crop=False, | 23 | center_crop=False, |
24 | valid_set_size=None, | ||
24 | collate_fn=None): | 25 | collate_fn=None): |
25 | super().__init__() | 26 | super().__init__() |
26 | 27 | ||
@@ -39,6 +40,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
39 | self.identifier = identifier | 40 | self.identifier = identifier |
40 | self.center_crop = center_crop | 41 | self.center_crop = center_crop |
41 | self.interpolation = interpolation | 42 | self.interpolation = interpolation |
43 | self.valid_set_size = valid_set_size | ||
42 | self.collate_fn = collate_fn | 44 | self.collate_fn = collate_fn |
43 | self.batch_size = batch_size | 45 | self.batch_size = batch_size |
44 | 46 | ||
@@ -50,8 +52,11 @@ class CSVDataModule(pl.LightningDataModule): | |||
50 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 52 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] |
51 | 53 | ||
52 | def setup(self, stage=None): | 54 | def setup(self, stage=None): |
53 | train_set_size = int(len(self.data_full) * 0.8) | 55 | valid_set_size = int(len(self.data_full) * 0.2) |
54 | valid_set_size = len(self.data_full) - train_set_size | 56 | if self.valid_set_size: |
57 | valid_set_size = min(valid_set_size, self.valid_set_size) | ||
58 | train_set_size = len(self.data_full) - valid_set_size | ||
59 | |||
55 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 60 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) |
56 | 61 | ||
57 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 62 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, |
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index f306c7a..e082511 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
@@ -18,7 +18,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
18 | repeats=100, | 18 | repeats=100, |
19 | interpolation="bicubic", | 19 | interpolation="bicubic", |
20 | placeholder_token="*", | 20 | placeholder_token="*", |
21 | center_crop=False): | 21 | center_crop=False, |
22 | valid_set_size=None): | ||
22 | super().__init__() | 23 | super().__init__() |
23 | 24 | ||
24 | self.data_file = Path(data_file) | 25 | self.data_file = Path(data_file) |
@@ -33,6 +34,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
33 | self.placeholder_token = placeholder_token | 34 | self.placeholder_token = placeholder_token |
34 | self.center_crop = center_crop | 35 | self.center_crop = center_crop |
35 | self.interpolation = interpolation | 36 | self.interpolation = interpolation |
37 | self.valid_set_size = valid_set_size | ||
36 | 38 | ||
37 | self.batch_size = batch_size | 39 | self.batch_size = batch_size |
38 | 40 | ||
@@ -44,8 +46,11 @@ class CSVDataModule(pl.LightningDataModule): | |||
44 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 46 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] |
45 | 47 | ||
46 | def setup(self, stage=None): | 48 | def setup(self, stage=None): |
47 | train_set_size = int(len(self.data_full) * 0.8) | 49 | valid_set_size = int(len(self.data_full) * 0.2) |
48 | valid_set_size = len(self.data_full) - train_set_size | 50 | if self.valid_set_size: |
51 | valid_set_size = min(valid_set_size, self.valid_set_size) | ||
52 | train_set_size = len(self.data_full) - valid_set_size | ||
53 | |||
49 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 54 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) |
50 | 55 | ||
51 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, | 56 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, |
diff --git a/dreambooth.py b/dreambooth.py index 744d1bc..88cd0da 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -638,6 +638,7 @@ def main(): | |||
638 | identifier=args.identifier, | 638 | identifier=args.identifier, |
639 | repeats=args.repeats, | 639 | repeats=args.repeats, |
640 | center_crop=args.center_crop, | 640 | center_crop=args.center_crop, |
641 | valid_set_size=args.sample_batch_size*args.stable_sample_batches, | ||
641 | collate_fn=collate_fn) | 642 | collate_fn=collate_fn) |
642 | 643 | ||
643 | datamodule.prepare_data() | 644 | datamodule.prepare_data() |
@@ -658,7 +659,7 @@ def main(): | |||
658 | sample_batch_size=args.sample_batch_size, | 659 | sample_batch_size=args.sample_batch_size, |
659 | random_sample_batches=args.random_sample_batches, | 660 | random_sample_batches=args.random_sample_batches, |
660 | stable_sample_batches=args.stable_sample_batches, | 661 | stable_sample_batches=args.stable_sample_batches, |
661 | seed=args.seed | 662 | seed=args.seed or torch.random.seed() |
662 | ) | 663 | ) |
663 | 664 | ||
664 | # Scheduler and math around the number of training steps. | 665 | # Scheduler and math around the number of training steps. |
diff --git a/textual_inversion.py b/textual_inversion.py index 7a7d7fc..fa6214e 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -604,9 +604,15 @@ def main(): | |||
604 | ) | 604 | ) |
605 | 605 | ||
606 | datamodule = CSVDataModule( | 606 | datamodule = CSVDataModule( |
607 | data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, | 607 | data_file=args.train_data_file, |
608 | size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats, | 608 | batch_size=args.train_batch_size, |
609 | center_crop=args.center_crop) | 609 | tokenizer=tokenizer, |
610 | size=args.resolution, | ||
611 | placeholder_token=args.placeholder_token, | ||
612 | repeats=args.repeats, | ||
613 | center_crop=args.center_crop, | ||
614 | valid_set_size=args.sample_batch_size*args.stable_sample_batches | ||
615 | ) | ||
610 | 616 | ||
611 | datamodule.prepare_data() | 617 | datamodule.prepare_data() |
612 | datamodule.setup() | 618 | datamodule.setup() |