summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py
index df3ee77..b058a3e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -121,7 +121,7 @@ def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
121 inputs = unify_input_ids(tokenizer, input_ids) 121 inputs = unify_input_ids(tokenizer, input_ids)
122 122
123 batch = { 123 batch = {
124 "with_prior": torch.tensor(with_prior), 124 "with_prior": torch.tensor([with_prior] * len(examples)),
125 "prompt_ids": prompts.input_ids, 125 "prompt_ids": prompts.input_ids,
126 "nprompt_ids": nprompts.input_ids, 126 "nprompt_ids": nprompts.input_ids,
127 "input_ids": inputs.input_ids, 127 "input_ids": inputs.input_ids,