From 46b6c09a18b41edff77c6881529b66733d788abe Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 21:28:52 +0200 Subject: Dreambooth: Generate specialized class images from input prompts --- data/textual_inversion/csv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'data/textual_inversion') diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 852b1cb..4c5e27e 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py @@ -52,13 +52,14 @@ class CSVDataModule(pl.LightningDataModule): valid_set_size = int(len(self.data_full) * 0.2) if self.valid_set_size: valid_set_size = min(valid_set_size, self.valid_set_size) + valid_set_size = max(valid_set_size, 1) train_set_size = len(self.data_full) - valid_set_size self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, placeholder_token=self.placeholder_token, center_crop=self.center_crop) - val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, + val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, placeholder_token=self.placeholder_token, center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) -- cgit v1.2.3-54-g00ecf