diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/csv.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index a6cd065..d52d251 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -104,11 +104,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool | |||
104 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 104 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
105 | 105 | ||
106 | input_ids = [example["instance_prompt_ids"] for example in examples] | 106 | input_ids = [example["instance_prompt_ids"] for example in examples] |
107 | negative_input_ids = [example["negative_prompt_ids"] for example in examples] | ||
107 | pixel_values = [example["instance_images"] for example in examples] | 108 | pixel_values = [example["instance_images"] for example in examples] |
108 | 109 | ||
109 | if with_guidance: | 110 | if with_prior_preservation: |
110 | input_ids += [example["negative_prompt_ids"] for example in examples] | ||
111 | elif with_prior_preservation: | ||
112 | input_ids += [example["class_prompt_ids"] for example in examples] | 111 | input_ids += [example["class_prompt_ids"] for example in examples] |
113 | pixel_values += [example["class_images"] for example in examples] | 112 | pixel_values += [example["class_images"] for example in examples] |
114 | 113 | ||
@@ -118,13 +117,16 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool | |||
118 | prompts = unify_input_ids(tokenizer, prompt_ids) | 117 | prompts = unify_input_ids(tokenizer, prompt_ids) |
119 | nprompts = unify_input_ids(tokenizer, nprompt_ids) | 118 | nprompts = unify_input_ids(tokenizer, nprompt_ids) |
120 | inputs = unify_input_ids(tokenizer, input_ids) | 119 | inputs = unify_input_ids(tokenizer, input_ids) |
120 | negative_inputs = unify_input_ids(tokenizer, negative_input_ids) | ||
121 | 121 | ||
122 | batch = { | 122 | batch = { |
123 | "prompt_ids": prompts.input_ids, | 123 | "prompt_ids": prompts.input_ids, |
124 | "nprompt_ids": nprompts.input_ids, | 124 | "nprompt_ids": nprompts.input_ids, |
125 | "input_ids": inputs.input_ids, | 125 | "input_ids": inputs.input_ids, |
126 | "negative_input_ids": negative_inputs.attention_mask, | ||
126 | "pixel_values": pixel_values, | 127 | "pixel_values": pixel_values, |
127 | "attention_mask": inputs.attention_mask, | 128 | "attention_mask": inputs.attention_mask, |
129 | "negative_attention_mask": negative_inputs.attention_mask, | ||
128 | } | 130 | } |
129 | 131 | ||
130 | return batch | 132 | return batch |