Diffstat (limited to 'data')
-rw-r--r--  data/csv.py  11
1 file changed, 4 insertions, 7 deletions
diff --git a/data/csv.py b/data/csv.py
index b058a3e..5de3ac7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,28 +100,25 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
-    with_prior = all("class_prompt_ids" in example for example in examples)
-
+def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    if with_prior:
+    if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
     pixel_values = torch.stack(pixel_values)
-    pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+    pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
 
     prompts = unify_input_ids(tokenizer, prompt_ids)
     nprompts = unify_input_ids(tokenizer, nprompt_ids)
     inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
-        "with_prior": torch.tensor([with_prior] * len(examples)),
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
@@ -285,7 +282,7 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
 
         self.train_dataloader = DataLoader(
             train_dataset,
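For reference, a minimal, self-contained sketch of the pattern this commit relies on: configuration arguments are bound to the front of the collate function with `functools.partial`, and `DataLoader` then supplies only the trailing list of examples. The names `collate` and `double` below are hypothetical stand-ins for illustration, not code from this repository:

    from functools import partial

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def collate(dtype: torch.dtype, double: bool, examples):
        # Configuration arguments come first and are bound once via partial;
        # DataLoader calls the resulting function with only the list of
        # samples, matching the trailing `examples` parameter in the diff.
        batch = torch.stack([item[0] for item in examples])
        if double:
            # Stand-in for appending class images under prior preservation.
            batch = torch.cat([batch, batch])
        return batch.to(dtype=dtype, memory_format=torch.contiguous_format)

    dataset = TensorDataset(torch.randn(8, 3))
    loader = DataLoader(dataset, batch_size=4,
                        collate_fn=partial(collate, torch.float32, True))

    for batch in loader:
        print(batch.shape)  # torch.Size([8, 3]): 4 samples, duplicated by the flag

Deciding `with_prior_preservation` once from configuration (`self.num_class_images != 0`) also removes the per-batch scan of every example for `class_prompt_ids` that the deleted `all(...)` line performed, and makes the redundant per-example `with_prior` tensor in the batch unnecessary.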
