summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 09:35:42 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 09:35:42 +0100
commit6c38d0088ece492696a7bc94a5cb43a48289452a (patch)
treed84a8fefd52eba5cbf38e64d34962f34dc6d047d /data
parentCleanup (diff)
downloadtextual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.gz
textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.bz2
textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.zip
Fix
Diffstat (limited to 'data')
-rw-r--r--data/csv.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py
index df3ee77..b058a3e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -121,7 +121,7 @@ def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
121 inputs = unify_input_ids(tokenizer, input_ids) 121 inputs = unify_input_ids(tokenizer, input_ids)
122 122
123 batch = { 123 batch = {
124 "with_prior": torch.tensor(with_prior), 124 "with_prior": torch.tensor([with_prior] * len(examples)),
125 "prompt_ids": prompts.input_ids, 125 "prompt_ids": prompts.input_ids,
126 "nprompt_ids": nprompts.input_ids, 126 "nprompt_ids": nprompts.input_ids,
127 "input_ids": inputs.input_ids, 127 "input_ids": inputs.input_ids,