summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-30 14:04:59 +0100
committerVolpeon <git@volpeon.ink>2022-12-30 14:04:59 +0100
commit799a2ed9c9735d11887600ee57ebb7471cdf6f43 (patch)
tree22a982d7348762f3cc55e91ba1e173f14c86cb99 /data
parentTraining script improvements (diff)
downloadtextual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.gz
textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.tar.bz2
textual-inversion-diff-799a2ed9c9735d11887600ee57ebb7471cdf6f43.zip
Misc improvements
Diffstat (limited to 'data')
-rw-r--r--data/csv.py40
1 files changed, 20 insertions, 20 deletions
diff --git a/data/csv.py b/data/csv.py
index 4da5d64..803271b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,28 +41,28 @@ class CSVDataItem(NamedTuple):
41 prompt: list[str] 41 prompt: list[str]
42 cprompt: str 42 cprompt: str
43 nprompt: str 43 nprompt: str
44 mode: list[str] 44 collection: list[str]
45 45
46 46
47class CSVDataModule(): 47class CSVDataModule():
48 def __init__( 48 def __init__(
49 self, 49 self,
50 batch_size: int, 50 batch_size: int,
51 data_file: str, 51 data_file: str,
52 prompt_processor: PromptProcessor, 52 prompt_processor: PromptProcessor,
53 class_subdir: str = "cls", 53 class_subdir: str = "cls",
54 num_class_images: int = 1, 54 num_class_images: int = 1,
55 size: int = 768, 55 size: int = 768,
56 repeats: int = 1, 56 repeats: int = 1,
57 dropout: float = 0, 57 dropout: float = 0,
58 interpolation: str = "bicubic", 58 interpolation: str = "bicubic",
59 center_crop: bool = False, 59 center_crop: bool = False,
60 template_key: str = "template", 60 template_key: str = "template",
61 valid_set_size: Optional[int] = None, 61 valid_set_size: Optional[int] = None,
62 generator: Optional[torch.Generator] = None, 62 generator: Optional[torch.Generator] = None,
63 filter: Optional[Callable[[CSVDataItem], bool]] = None, 63 filter: Optional[Callable[[CSVDataItem], bool]] = None,
64 collate_fn=None, 64 collate_fn=None,
65 num_workers: int = 0 65 num_workers: int = 0
66 ): 66 ):
67 super().__init__() 67 super().__init__()
68 68
@@ -112,7 +112,7 @@ class CSVDataModule():
112 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 112 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
113 expansions 113 expansions
114 )), 114 )),
115 item["mode"].split(", ") if "mode" in item else [] 115 item["collection"].split(", ") if "collection" in item else []
116 ) 116 )
117 for item in data 117 for item in data
118 ] 118 ]
@@ -133,7 +133,7 @@ class CSVDataModule():
133 item.prompt, 133 item.prompt,
134 item.cprompt, 134 item.cprompt,
135 item.nprompt, 135 item.nprompt,
136 item.mode, 136 item.collection,
137 ) 137 )
138 for item in items 138 for item in items
139 for i in range(image_multiplier) 139 for i in range(image_multiplier)