From 89d471652644f449966a0cd944041c98dab7f66c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 07:25:24 +0100 Subject: Code deduplication --- data/csv.py | 47 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 5 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 9ad7dd6..f5fc8e6 100644 --- a/data/csv.py +++ b/data/csv.py @@ -1,7 +1,7 @@ import math import torch import json -import copy +from functools import partial from pathlib import Path from typing import NamedTuple, Optional, Union, Callable @@ -99,6 +99,41 @@ def generate_buckets( return buckets, bucket_items, bucket_assignments +def collate_fn( + num_class_images: int, + weight_dtype: torch.dtype, + prompt_processor: PromptProcessor, + examples +): + prompt_ids = [example["prompt_ids"] for example in examples] + nprompt_ids = [example["nprompt_ids"] for example in examples] + + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # concat class and instance examples for prior preservation + if num_class_images != 0 and "class_prompt_ids" in examples[0]: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) + + prompts = prompt_processor.unify_input_ids(prompt_ids) + nprompts = prompt_processor.unify_input_ids(nprompt_ids) + inputs = prompt_processor.unify_input_ids(input_ids) + + batch = { + "prompt_ids": prompts.input_ids, + "nprompt_ids": nprompts.input_ids, + "input_ids": inputs.input_ids, + "pixel_values": pixel_values, + "attention_mask": inputs.attention_mask, + } + + return batch + + class VlpnDataItem(NamedTuple): instance_image_path: Path class_image_path: Path @@ -129,7 +164,7 @@ class VlpnDataModule(): valid_set_repeat: int = 1, seed: Optional[int] = None, filter: Optional[Callable[[VlpnDataItem], bool]] = None, - collate_fn=None, + dtype: torch.dtype = torch.float32, num_workers: int = 0 ): super().__init__() @@ -158,9 +193,9 @@ class VlpnDataModule(): self.valid_set_repeat = valid_set_repeat self.seed = seed self.filter = filter - self.collate_fn = collate_fn self.num_workers = num_workers self.batch_size = batch_size + self.dtype = dtype def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: image = template["image"] if "image" in template else "{}" @@ -254,14 +289,16 @@ class VlpnDataModule(): size=self.size, interpolation=self.interpolation, ) + collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.prompt_processor) + self.train_dataloader = DataLoader( train_dataset, - batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers + batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers ) self.val_dataloader = DataLoader( val_dataset, - batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers + batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers ) -- cgit v1.2.3-70-g09d2