From d0ce16b542deac464e097c38adc5095802bd6763 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 3 Oct 2022 11:44:42 +0200
Subject: Assign unused images in validation dataset to train dataset

---
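Notes (commentary after the ---, ignored by git am):

setup() previously always reserved 20% of the data for validation. It now
treats 20% as an upper bound and additionally caps the validation set at
valid_set_size items when that is given; every image left over is assigned
to the train set, which is what the subject line refers to. A minimal
standalone sketch of the resulting behaviour, with hypothetical sizes (not
part of this patch):

    from torch.utils.data import random_split

    data_full = list(range(100))  # hypothetical dataset of 100 items
    requested = 4                 # e.g. sample_batch_size * stable_sample_batches

    valid_set_size = int(len(data_full) * 0.2)           # default cap: 20 items
    if requested:
        valid_set_size = min(valid_set_size, requested)  # tighter cap: 4 items
    train_set_size = len(data_full) - valid_set_size     # remainder: 96 items

    data_train, data_val = random_split(data_full, [train_set_size, valid_set_size])
    assert len(data_train) == 96 and len(data_val) == 4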
 data/dreambooth/csv.py        |  9 +++++++--
 data/textual_inversion/csv.py | 11 ++++++++---
 dreambooth.py                 |  3 ++-
 textual_inversion.py          | 12 +++++++++---
 4 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 1676d35..4087226 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -21,6 +21,7 @@ class CSVDataModule(pl.LightningDataModule):
                  interpolation="bicubic",
                  identifier="*",
                  center_crop=False,
+                 valid_set_size=None,
                  collate_fn=None):
         super().__init__()
 
@@ -39,6 +40,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.identifier = identifier
         self.center_crop = center_crop
         self.interpolation = interpolation
+        self.valid_set_size = valid_set_size
         self.collate_fn = collate_fn
 
         self.batch_size = batch_size
@@ -50,8 +52,11 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
 
     def setup(self, stage=None):
-        train_set_size = int(len(self.data_full) * 0.8)
-        valid_set_size = len(self.data_full) - train_set_size
+        valid_set_size = int(len(self.data_full) * 0.2)
+        if self.valid_set_size:
+            valid_set_size = min(valid_set_size, self.valid_set_size)
+        train_set_size = len(self.data_full) - valid_set_size
+
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index f306c7a..e082511 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -18,7 +18,8 @@ class CSVDataModule(pl.LightningDataModule):
                  repeats=100,
                  interpolation="bicubic",
                  placeholder_token="*",
-                 center_crop=False):
+                 center_crop=False,
+                 valid_set_size=None):
         super().__init__()
 
         self.data_file = Path(data_file)
@@ -33,6 +34,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.placeholder_token = placeholder_token
         self.center_crop = center_crop
         self.interpolation = interpolation
+        self.valid_set_size = valid_set_size
 
         self.batch_size = batch_size
 
@@ -44,8 +46,11 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
 
     def setup(self, stage=None):
-        train_set_size = int(len(self.data_full) * 0.8)
-        valid_set_size = len(self.data_full) - train_set_size
+        valid_set_size = int(len(self.data_full) * 0.2)
+        if self.valid_set_size:
+            valid_set_size = min(valid_set_size, self.valid_set_size)
+        train_set_size = len(self.data_full) - valid_set_size
+
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
diff --git a/dreambooth.py b/dreambooth.py
index 744d1bc..88cd0da 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -638,6 +638,7 @@ def main():
         identifier=args.identifier,
         repeats=args.repeats,
         center_crop=args.center_crop,
+        valid_set_size=args.sample_batch_size*args.stable_sample_batches,
         collate_fn=collate_fn)
 
     datamodule.prepare_data()
@@ -658,7 +659,7 @@ def main():
         sample_batch_size=args.sample_batch_size,
         random_sample_batches=args.random_sample_batches,
         stable_sample_batches=args.stable_sample_batches,
-        seed=args.seed
+        seed=args.seed or torch.random.seed()
     )
 
     # Scheduler and math around the number of training steps.
diff --git a/textual_inversion.py b/textual_inversion.py
index 7a7d7fc..fa6214e 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -604,9 +604,15 @@ def main():
     )
 
     datamodule = CSVDataModule(
-        data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer,
-        size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats,
-        center_crop=args.center_crop)
+        data_file=args.train_data_file,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        size=args.resolution,
+        placeholder_token=args.placeholder_token,
+        repeats=args.repeats,
+        center_crop=args.center_crop,
+        valid_set_size=args.sample_batch_size*args.stable_sample_batches
+    )
 
     datamodule.prepare_data()
     datamodule.setup()
-- 
cgit v1.2.3-70-g09d2
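Review note, not part of the patch: the seed fallback in dreambooth.py,
seed=args.seed or torch.random.seed(), draws a fresh 64-bit seed when no
--seed is given (torch.random.seed() also reseeds the default CPU generator
as a side effect). Because "or" treats 0 as falsy, an explicit --seed 0
would be replaced by a random seed too; comparing against None avoids that
edge case. A small sketch of the stricter variant, with a hypothetical
local variable standing in for args.seed:

    import torch

    seed = None  # stand-in for args.seed; None when --seed was omitted

    # Respect an explicit seed of 0 by testing for None instead of truthiness.
    effective_seed = seed if seed is not None else torch.random.seed()
    torch.manual_seed(effective_seed)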