Diffstat (limited to 'data')
-rw-r--r--  data/csv.py  11
1 file changed, 4 insertions(+), 7 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index b058a3e..5de3ac7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,28 +100,25 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
-    with_prior = all("class_prompt_ids" in example for example in examples)
-
+def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    if with_prior:
+    if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
     pixel_values = torch.stack(pixel_values)
-    pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+    pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
 
     prompts = unify_input_ids(tokenizer, prompt_ids)
     nprompts = unify_input_ids(tokenizer, nprompt_ids)
     inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
-        "with_prior": torch.tensor([with_prior] * len(examples)),
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
@@ -285,7 +282,7 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
 
         self.train_dataloader = DataLoader(
             train_dataset,
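For context, functools.partial pre-binds the first three arguments, so DataLoader still invokes the collate function with a single argument (the list of samples). A minimal runnable sketch under stand-in data (the toy dataset and the None tokenizer are hypothetical; the real call binds self.tokenizer and self.num_class_images != 0):

    from functools import partial

    import torch
    from torch.utils.data import DataLoader

    def collate_fn(dtype, tokenizer, with_prior_preservation, examples):
        # DataLoader passes only `examples`; dtype, tokenizer and the
        # prior-preservation flag were bound via partial() below.
        pixel_values = torch.stack([ex["instance_images"] for ex in examples])
        return pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)

    dataset = [{"instance_images": torch.rand(3, 8, 8)} for _ in range(4)]
    collate_fn_ = partial(collate_fn, torch.float32, None, False)
    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn_)
    print(next(iter(loader)).shape)  # torch.Size([2, 3, 8, 8])

Binding self.num_class_images != 0 once at dataloader construction makes the prior-preservation decision explicit and constant across batches, instead of being re-inferred from each batch's keys.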