From 799a2ed9c9735d11887600ee57ebb7471cdf6f43 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 30 Dec 2022 14:04:59 +0100 Subject: Misc improvements --- data/csv.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 4da5d64..803271b 100644 --- a/data/csv.py +++ b/data/csv.py @@ -41,28 +41,28 @@ class CSVDataItem(NamedTuple): prompt: list[str] cprompt: str nprompt: str - mode: list[str] + collection: list[str] class CSVDataModule(): def __init__( - self, - batch_size: int, - data_file: str, - prompt_processor: PromptProcessor, - class_subdir: str = "cls", - num_class_images: int = 1, - size: int = 768, - repeats: int = 1, - dropout: float = 0, - interpolation: str = "bicubic", - center_crop: bool = False, - template_key: str = "template", - valid_set_size: Optional[int] = None, - generator: Optional[torch.Generator] = None, - filter: Optional[Callable[[CSVDataItem], bool]] = None, - collate_fn=None, - num_workers: int = 0 + self, + batch_size: int, + data_file: str, + prompt_processor: PromptProcessor, + class_subdir: str = "cls", + num_class_images: int = 1, + size: int = 768, + repeats: int = 1, + dropout: float = 0, + interpolation: str = "bicubic", + center_crop: bool = False, + template_key: str = "template", + valid_set_size: Optional[int] = None, + generator: Optional[torch.Generator] = None, + filter: Optional[Callable[[CSVDataItem], bool]] = None, + collate_fn=None, + num_workers: int = 0 ): super().__init__() @@ -112,7 +112,7 @@ class CSVDataModule(): nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), expansions )), - item["mode"].split(", ") if "mode" in item else [] + item["collection"].split(", ") if "collection" in item else [] ) for item in data ] @@ -133,7 +133,7 @@ class CSVDataModule(): item.prompt, item.cprompt, item.nprompt, - item.mode, + item.collection, ) for item in items for i in range(image_multiplier) -- cgit v1.2.3-54-g00ecf