diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-14 09:35:42 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 09:35:42 +0100 |
| commit | 6c38d0088ece492696a7bc94a5cb43a48289452a (patch) | |
| tree | d84a8fefd52eba5cbf38e64d34962f34dc6d047d /data | |
| parent | Cleanup (diff) | |
| download | textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.gz textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.bz2 textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.zip | |
Fix
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py index df3ee77..b058a3e 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -121,7 +121,7 @@ def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples): | |||
| 121 | inputs = unify_input_ids(tokenizer, input_ids) | 121 | inputs = unify_input_ids(tokenizer, input_ids) |
| 122 | 122 | ||
| 123 | batch = { | 123 | batch = { |
| 124 | "with_prior": torch.tensor(with_prior), | 124 | "with_prior": torch.tensor([with_prior] * len(examples)), |
| 125 | "prompt_ids": prompts.input_ids, | 125 | "prompt_ids": prompts.input_ids, |
| 126 | "nprompt_ids": nprompts.input_ids, | 126 | "nprompt_ids": nprompts.input_ids, |
| 127 | "input_ids": inputs.input_ids, | 127 | "input_ids": inputs.input_ids, |
