From 6c38d0088ece492696a7bc94a5cb43a48289452a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 09:35:42 +0100 Subject: Fix --- data/csv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index df3ee77..b058a3e 100644 --- a/data/csv.py +++ b/data/csv.py @@ -121,7 +121,7 @@ def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples): inputs = unify_input_ids(tokenizer, input_ids) batch = { - "with_prior": torch.tensor(with_prior), + "with_prior": torch.tensor([with_prior] * len(examples)), "prompt_ids": prompts.input_ids, "nprompt_ids": nprompts.input_ids, "input_ids": inputs.input_ids, -- cgit v1.2.3-54-g00ecf