From e2d3a62bce63fcde940395a1c5618c4eb43385a9 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 09:25:13 +0100 Subject: Cleanup --- data/csv.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index a3fef30..df3ee77 100644 --- a/data/csv.py +++ b/data/csv.py @@ -100,20 +100,16 @@ def generate_buckets( return buckets, bucket_items, bucket_assignments -def collate_fn( - num_class_images: int, - weight_dtype: torch.dtype, - tokenizer: CLIPTokenizer, - examples -): +def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples): + with_prior = all("class_prompt_ids" in example for example in examples) + prompt_ids = [example["prompt_ids"] for example in examples] nprompt_ids = [example["nprompt_ids"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - # concat class and instance examples for prior preservation - if num_class_images != 0 and "class_prompt_ids" in examples[0]: + if with_prior: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] @@ -125,6 +121,7 @@ def collate_fn( inputs = unify_input_ids(tokenizer, input_ids) batch = { + "with_prior": torch.tensor(with_prior), "prompt_ids": prompts.input_ids, "nprompt_ids": nprompts.input_ids, "input_ids": inputs.input_ids, @@ -166,7 +163,6 @@ class VlpnDataModule(): seed: Optional[int] = None, filter: Optional[Callable[[VlpnDataItem], bool]] = None, dtype: torch.dtype = torch.float32, - num_workers: int = 0 ): super().__init__() @@ -194,7 +190,6 @@ class VlpnDataModule(): self.valid_set_repeat = valid_set_repeat self.seed = seed self.filter = filter - self.num_workers = num_workers self.batch_size = batch_size self.dtype = dtype @@ -290,16 +285,16 @@ class VlpnDataModule(): size=self.size, interpolation=self.interpolation, ) - collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer) + collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer) self.train_dataloader = DataLoader( train_dataset, - batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers + batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) self.val_dataloader = DataLoader( val_dataset, - batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers + batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) -- cgit v1.2.3-54-g00ecf