From 59bf501198d7ff6c0c03c45e92adef14069d5ac6 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 15 Jan 2023 12:33:52 +0100 Subject: Update --- data/csv.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index b058a3e..5de3ac7 100644 --- a/data/csv.py +++ b/data/csv.py @@ -100,28 +100,25 @@ def generate_buckets( return buckets, bucket_items, bucket_assignments -def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples): - with_prior = all("class_prompt_ids" in example for example in examples) - +def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples): prompt_ids = [example["prompt_ids"] for example in examples] nprompt_ids = [example["nprompt_ids"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - if with_prior: + if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) + pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format) prompts = unify_input_ids(tokenizer, prompt_ids) nprompts = unify_input_ids(tokenizer, nprompt_ids) inputs = unify_input_ids(tokenizer, input_ids) batch = { - "with_prior": torch.tensor([with_prior] * len(examples)), "prompt_ids": prompts.input_ids, "nprompt_ids": nprompts.input_ids, "input_ids": inputs.input_ids, @@ -285,7 +282,7 @@ class VlpnDataModule(): size=self.size, interpolation=self.interpolation, ) - collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer) + collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) self.train_dataloader = DataLoader( train_dataset, -- cgit v1.2.3-54-g00ecf