summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-30 13:48:26 +0100
committerVolpeon <git@volpeon.ink>2022-12-30 13:48:26 +0100
commitdfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0 (patch)
treeda07cbadfad6f54e55e43e2fda21cef80cded5ea /data
parentUpdate (diff)
downloadtextual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.gz
textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.tar.bz2
textual-inversion-diff-dfcfd6bc1db6b9eb12c8321d18fc7a461710e7e0.zip
Training script improvements
Diffstat (limited to 'data')
-rw-r--r--data/csv.py15
1 files changed, 4 insertions, 11 deletions
diff --git a/data/csv.py b/data/csv.py
index 0ad36dc..4da5d64 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,6 +41,7 @@ class CSVDataItem(NamedTuple):
41 prompt: list[str] 41 prompt: list[str]
42 cprompt: str 42 cprompt: str
43 nprompt: str 43 nprompt: str
44 mode: list[str]
44 45
45 46
46class CSVDataModule(): 47class CSVDataModule():
@@ -56,7 +57,6 @@ class CSVDataModule():
56 dropout: float = 0, 57 dropout: float = 0,
57 interpolation: str = "bicubic", 58 interpolation: str = "bicubic",
58 center_crop: bool = False, 59 center_crop: bool = False,
59 mode: Optional[str] = None,
60 template_key: str = "template", 60 template_key: str = "template",
61 valid_set_size: Optional[int] = None, 61 valid_set_size: Optional[int] = None,
62 generator: Optional[torch.Generator] = None, 62 generator: Optional[torch.Generator] = None,
@@ -81,7 +81,6 @@ class CSVDataModule():
81 self.repeats = repeats 81 self.repeats = repeats
82 self.dropout = dropout 82 self.dropout = dropout
83 self.center_crop = center_crop 83 self.center_crop = center_crop
84 self.mode = mode
85 self.template_key = template_key 84 self.template_key = template_key
86 self.interpolation = interpolation 85 self.interpolation = interpolation
87 self.valid_set_size = valid_set_size 86 self.valid_set_size = valid_set_size
@@ -113,6 +112,7 @@ class CSVDataModule():
113 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 112 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
114 expansions 113 expansions
115 )), 114 )),
115 item["mode"].split(", ") if "mode" in item else []
116 ) 116 )
117 for item in data 117 for item in data
118 ] 118 ]
@@ -133,6 +133,7 @@ class CSVDataModule():
133 item.prompt, 133 item.prompt,
134 item.cprompt, 134 item.cprompt,
135 item.nprompt, 135 item.nprompt,
136 item.mode,
136 ) 137 )
137 for item in items 138 for item in items
138 for i in range(image_multiplier) 139 for i in range(image_multiplier)
@@ -145,20 +146,12 @@ class CSVDataModule():
145 expansions = metadata["expansions"] if "expansions" in metadata else {} 146 expansions = metadata["expansions"] if "expansions" in metadata else {}
146 items = metadata["items"] if "items" in metadata else [] 147 items = metadata["items"] if "items" in metadata else []
147 148
148 if self.mode is not None:
149 items = [
150 item
151 for item in items
152 if "mode" in item and self.mode in item["mode"].split(", ")
153 ]
154 items = self.prepare_items(template, expansions, items) 149 items = self.prepare_items(template, expansions, items)
155 items = self.filter_items(items) 150 items = self.filter_items(items)
156 151
157 num_images = len(items) 152 num_images = len(items)
158 153
159 valid_set_size = int(num_images * 0.1) 154 valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1)
160 if self.valid_set_size:
161 valid_set_size = min(valid_set_size, self.valid_set_size)
162 valid_set_size = max(valid_set_size, 1) 155 valid_set_size = max(valid_set_size, 1)
163 train_set_size = num_images - valid_set_size 156 train_set_size = num_images - valid_set_size
164 157