From 94b676d91382267e7429bd68362019868affd9d1 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 13 Feb 2023 17:19:18 +0100 Subject: Update --- data/csv.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index b4c81d7..c5902ed 100644 --- a/data/csv.py +++ b/data/csv.py @@ -42,7 +42,7 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]): def generate_buckets( - items: list[str], + items: Union[list[str], list[Path]], base_size: int, step_size: int = 64, max_pixels: Optional[int] = None, @@ -188,7 +188,7 @@ class VlpnDataModule(): raise ValueError("data_file must be a file") self.data_root = self.data_file.parent - self.class_root = self.data_root.joinpath(class_subdir) + self.class_root = self.data_root / class_subdir self.class_root.mkdir(parents=True, exist_ok=True) self.num_class_images = num_class_images @@ -218,7 +218,7 @@ class VlpnDataModule(): return [ VlpnDataItem( - self.data_root.joinpath(image.format(item["image"])), + self.data_root / image.format(item["image"]), None, prompt_to_keywords( prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), @@ -249,7 +249,7 @@ class VlpnDataModule(): return [ VlpnDataItem( item.instance_image_path, - self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), + self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", item.prompt, item.cprompt, item.nprompt, -- cgit v1.2.3-70-g09d2