summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-01 16:19:01 +0100
committerVolpeon <git@volpeon.ink>2022-11-01 16:19:01 +0100
commitb2c3389e9c6375d9081625e75a99de98395f8e77 (patch)
treed230b417314960e8705abd2eeaa3b55d9b70c754 /data
parentFix (diff)
downloadtextual-inversion-diff-b2c3389e9c6375d9081625e75a99de98395f8e77.tar.gz
textual-inversion-diff-b2c3389e9c6375d9081625e75a99de98395f8e77.tar.bz2
textual-inversion-diff-b2c3389e9c6375d9081625e75a99de98395f8e77.zip
Update
Diffstat (limited to 'data')
-rw-r--r--data/csv.py11
1 files changed, 1 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py
index 6bd7f9b..793fbf8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -150,7 +150,6 @@ class CSVDataset(Dataset):
150 self.class_identifier = class_identifier 150 self.class_identifier = class_identifier
151 self.num_class_images = num_class_images 151 self.num_class_images = num_class_images
152 self.image_cache = {} 152 self.image_cache = {}
153 self.input_id_cache = {}
154 153
155 self.num_instance_images = len(self.data) 154 self.num_instance_images = len(self.data)
156 self._length = self.num_instance_images * repeats 155 self._length = self.num_instance_images * repeats
@@ -185,15 +184,7 @@ class CSVDataset(Dataset):
185 return image 184 return image
186 185
187 def get_input_ids(self, prompt, identifier): 186 def get_input_ids(self, prompt, identifier):
188 prompt = prompt.format(identifier) 187 return self.prompt_processor.get_input_ids(prompt.format(identifier))
189
190 if prompt in self.input_id_cache:
191 return self.input_id_cache[prompt]
192
193 input_ids = self.prompt_processor.get_input_ids(prompt)
194 self.input_id_cache[prompt] = input_ids
195
196 return input_ids
197 188
198 def get_example(self, i): 189 def get_example(self, i):
199 item = self.data[i % self.num_instance_images] 190 item = self.data[i % self.num_instance_images]