From 2a65b4eb29e4874c153a9517ab06b93481c2d238 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 28 Sep 2022 18:32:15 +0200 Subject: Batches of size 1 cause error: Expected query.is_contiguous() to be true, but got false --- data/dreambooth/csv.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'data/dreambooth') diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 85ed4a5..99bcf12 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -1,3 +1,4 @@ +import math import os import pandas as pd from pathlib import Path @@ -57,11 +58,10 @@ class CSVDataModule(pl.LightningDataModule): train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, class_data_root=self.class_data_root, class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation, identifier=self.identifier, - center_crop=self.center_crop, repeats=self.repeats) + center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size) val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, - class_data_root=self.class_data_root, class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation, identifier=self.identifier, - center_crop=self.center_crop) + center_crop=self.center_crop, batch_size=self.batch_size) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn) self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) @@ -85,22 +85,24 @@ class CSVDataset(Dataset): interpolation="bicubic", identifier="*", center_crop=False, + batch_size=1, ): self.data = data self.tokenizer = tokenizer self.instance_prompt = instance_prompt + self.identifier = identifier + self.batch_size = batch_size + self.cache = {} self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats - self.identifier = identifier - if class_data_root is not None: self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images = list(Path(class_data_root).iterdir()) + self.class_images = list(self.class_data_root.iterdir()) self.num_class_images = len(self.class_images) self._length = max(self.num_class_images, self.num_instance_images) @@ -123,10 +125,8 @@ class CSVDataset(Dataset): ] ) - self.cache = {} - def __len__(self): - return self._length + return math.ceil(self._length / self.batch_size) * self.batch_size def get_example(self, i): image_path, text = self.data[i % self.num_instance_images] -- cgit v1.2.3-70-g09d2