diff options
author | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
commit | 94b676d91382267e7429bd68362019868affd9d1 (patch) | |
tree | 513697739ab25217cbfcff630299d02b1f6e98c8 /data | |
parent | Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline (diff) | |
download | textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.gz textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.bz2 textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.zip |
Update
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/data/csv.py b/data/csv.py index b4c81d7..c5902ed 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -42,7 +42,7 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]): | |||
42 | 42 | ||
43 | 43 | ||
44 | def generate_buckets( | 44 | def generate_buckets( |
45 | items: list[str], | 45 | items: Union[list[str], list[Path]], |
46 | base_size: int, | 46 | base_size: int, |
47 | step_size: int = 64, | 47 | step_size: int = 64, |
48 | max_pixels: Optional[int] = None, | 48 | max_pixels: Optional[int] = None, |
@@ -188,7 +188,7 @@ class VlpnDataModule(): | |||
188 | raise ValueError("data_file must be a file") | 188 | raise ValueError("data_file must be a file") |
189 | 189 | ||
190 | self.data_root = self.data_file.parent | 190 | self.data_root = self.data_file.parent |
191 | self.class_root = self.data_root.joinpath(class_subdir) | 191 | self.class_root = self.data_root / class_subdir |
192 | self.class_root.mkdir(parents=True, exist_ok=True) | 192 | self.class_root.mkdir(parents=True, exist_ok=True) |
193 | self.num_class_images = num_class_images | 193 | self.num_class_images = num_class_images |
194 | 194 | ||
@@ -218,7 +218,7 @@ class VlpnDataModule(): | |||
218 | 218 | ||
219 | return [ | 219 | return [ |
220 | VlpnDataItem( | 220 | VlpnDataItem( |
221 | self.data_root.joinpath(image.format(item["image"])), | 221 | self.data_root / image.format(item["image"]), |
222 | None, | 222 | None, |
223 | prompt_to_keywords( | 223 | prompt_to_keywords( |
224 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 224 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
@@ -249,7 +249,7 @@ class VlpnDataModule(): | |||
249 | return [ | 249 | return [ |
250 | VlpnDataItem( | 250 | VlpnDataItem( |
251 | item.instance_image_path, | 251 | item.instance_image_path, |
252 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | 252 | self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", |
253 | item.prompt, | 253 | item.prompt, |
254 | item.cprompt, | 254 | item.cprompt, |
255 | item.nprompt, | 255 | item.nprompt, |