import math
import json
from functools import partial
from pathlib import Path
from typing import NamedTuple, Optional, Union, Callable

from PIL import Image
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader, random_split
from torchvision import transforms
from transformers import CLIPTokenizer

from data.keywords import str_to_keywords, keywords_to_str
from models.clip.util import unify_input_ids


# In-memory cache so each image file is opened and decoded only once.
cache = {}

interpolations = {
    "linear": transforms.InterpolationMode.NEAREST,
    "bilinear": transforms.InterpolationMode.BILINEAR,
    "bicubic": transforms.InterpolationMode.BICUBIC,
    "lanczos": transforms.InterpolationMode.LANCZOS,
}


def get_image(path):
    if path in cache:
        return cache[path]

    image = Image.open(path)
    if not image.mode == "RGB":
        image = image.convert("RGB")
    cache[path] = image
    return image


def prepare_tpl_slots(prompt: Union[str, dict[str, str]]):
    return {"content": prompt} if isinstance(prompt, str) else prompt


def generate_buckets(
    items: Union[list[str], list[Path]],
    base_size: int,
    step_size: int = 64,
    max_pixels: Optional[int] = None,
    num_buckets: int = 4,
    progressive_buckets: bool = False,
    return_tensor: bool = True
):
    if max_pixels is None:
        max_pixels = (base_size + step_size) ** 2

    max_pixels = max(max_pixels, base_size * base_size)

    bucket_items: list[int] = []
    bucket_assignments: list[int] = []

    # Candidate aspect ratios: 1.0 plus progressively wider/taller ratios whose
    # pixel count stays within max_pixels.
    buckets = [1.0]

    for i in range(1, num_buckets + 1):
        long_side = base_size + i * step_size
        short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size)
        buckets.append(long_side / short_side)
        buckets.append(short_side / long_side)

    buckets = torch.tensor(buckets)
    bucket_indices = torch.arange(len(buckets))

    for i, item in enumerate(items):
        image = get_image(item)
        ratio = image.width / image.height

        if ratio >= 1:
            mask = torch.logical_and(buckets >= 1, buckets <= ratio)
        else:
            mask = torch.logical_and(buckets <= 1, buckets >= ratio)

        if not progressive_buckets:
            # Assign the image only to the bucket with the closest aspect ratio.
            inf = torch.zeros_like(buckets)
            inf[~mask] = math.inf
            mask = (buckets + inf - ratio).abs().argmin()

        indices = bucket_indices[mask]

        if len(indices.shape) == 0:
            indices = indices.unsqueeze(0)

        bucket_items += [i] * len(indices)
        bucket_assignments += indices

    if return_tensor:
        bucket_items = torch.tensor(bucket_items)
        bucket_assignments = torch.tensor(bucket_assignments)
    else:
        buckets = buckets.tolist()

    return buckets, bucket_items, bucket_assignments


def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples):
    prompt_ids = [example["prompt_ids"] for example in examples]
    nprompt_ids = [example["nprompt_ids"] for example in examples]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    negative_input_ids = [example["negative_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    if with_prior_preservation:
        # Prior preservation: class examples share the batch with the instance examples.
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)

    prompts = unify_input_ids(tokenizer, prompt_ids)
    nprompts = unify_input_ids(tokenizer, nprompt_ids)
    inputs = unify_input_ids(tokenizer, input_ids)
    negative_inputs = unify_input_ids(tokenizer, negative_input_ids)

    batch = {
        "prompt_ids": prompts.input_ids,
        "nprompt_ids": nprompts.input_ids,
        "input_ids": inputs.input_ids,
        "negative_input_ids": negative_inputs.input_ids,
"pixel_values": pixel_values, "attention_mask": inputs.attention_mask, "negative_attention_mask": negative_inputs.attention_mask, } return batch class VlpnDataItem(NamedTuple): instance_image_path: Path class_image_path: Path keywords: list[str] prompt: str cprompt: str nprompt: str collection: list[str] def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None): return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator) def keyword_filter( placeholder_tokens: Optional[list[str]], collections: Optional[list[str]], exclude_collections: Optional[list[str]], item: VlpnDataItem ): full_prompt = item.full_prompt() cond1 = placeholder_tokens is None or any( token in full_prompt for token in placeholder_tokens ) cond2 = collections is None or any( collection in item.collection for collection in collections ) cond3 = exclude_collections is None or not any( collection in item.collection for collection in exclude_collections ) return cond1 and cond2 and cond3 class VlpnDataModule(): def __init__( self, batch_size: int, data_file: str, tokenizer: CLIPTokenizer, class_subdir: str = "cls", with_guidance: bool = False, num_class_images: int = 1, size: int = 768, num_buckets: int = 0, bucket_step_size: int = 64, bucket_max_pixels: Optional[int] = None, progressive_buckets: bool = False, dropout: float = 0, shuffle: bool = False, interpolation: str = "bicubic", color_jitter: bool = True, template_key: str = "template", placeholder_tokens: list[str] = [], valid_set_size: Optional[int] = None, train_set_pad: Optional[int] = None, valid_set_pad: Optional[int] = None, generator: Optional[torch.Generator] = None, npgenerator: Optional[np.random.Generator] = None, filter: Optional[Callable[[VlpnDataItem], bool]] = None, dtype: torch.dtype = torch.float32, ): super().__init__() self.data_file = Path(data_file) if not self.data_file.is_file(): raise ValueError("data_file must be a file") self.data_root = self.data_file.parent self.class_root = self.data_root / class_subdir self.class_root.mkdir(parents=True, exist_ok=True) self.placeholder_tokens = placeholder_tokens self.num_class_images = num_class_images self.with_guidance = with_guidance self.tokenizer = tokenizer self.size = size self.num_buckets = num_buckets self.bucket_step_size = bucket_step_size self.bucket_max_pixels = bucket_max_pixels self.progressive_buckets = progressive_buckets self.dropout = dropout self.shuffle = shuffle self.template_key = template_key self.interpolation = interpolation self.color_jitter = color_jitter self.valid_set_size = valid_set_size self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size self.filter = filter self.batch_size = batch_size self.dtype = dtype self.generator = generator self.npgenerator = npgenerator or np.random.default_rng() def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: tpl_image = template["image"] if "image" in template else "{}" tpl_keywords = template["keywords"] if "keywords" in template else "{content}" tpl_prompt = template["prompt"] if "prompt" in template else "{content}" tpl_cprompt = template["cprompt"] if "cprompt" in template else "{content}" tpl_nprompt = template["nprompt"] if "nprompt" in template else "{content}" items = [] for item in data: image = tpl_image.format(item["image"]) keywords = prepare_tpl_slots(item["keywords"] if "keywords" in item else "") prompt = 
prepare_tpl_slots(item["prompt"] if "prompt" in item else "") nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") collection = item["collection"].split(", ") if "collection" in item else [] saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) inverted_tokens = keywords_to_str([ f"inv_{token}" for token in self.placeholder_tokens if token in saturated_keywords ]) items.append(VlpnDataItem( self.data_root / image, None, saturated_keywords, tpl_prompt.format(**prompt), tpl_cprompt.format(**prompt), tpl_nprompt.format(_inv=inverted_tokens, **nprompt), collection )) return items def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: if self.filter is None: return items return [item for item in items if self.filter(item)] def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: image_multiplier = max(num_class_images, 1) return [ VlpnDataItem( item.instance_image_path, self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", item.keywords, item.prompt, item.cprompt, item.nprompt, item.collection, ) for item in items for i in range(image_multiplier) ] def setup(self): with open(self.data_file, 'rt') as f: metadata = json.load(f) template = metadata[self.template_key] if self.template_key in metadata else {} expansions = metadata["expansions"] if "expansions" in metadata else {} items = metadata["items"] if "items" in metadata else [] items = self.prepare_items(template, expansions, items) items = self.filter_items(items) self.npgenerator.shuffle(items) num_images = len(items) valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 train_set_size = max(num_images - valid_set_size, 1) valid_set_size = num_images - train_set_size collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) if valid_set_size == 0: data_train, data_val = items, items else: data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator) data_train = self.pad_items(data_train, self.num_class_images) if len(data_train) < self.train_set_pad: data_train *= math.ceil(self.train_set_pad / len(data_train)) self.train_dataset = VlpnDataset( data_train, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, fill_batch=True, generator=self.generator, size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, ) self.train_dataloader = DataLoader( self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) if len(data_val) != 0: data_val = self.pad_items(data_val) if len(data_val) < self.valid_set_pad: data_val *= math.ceil(self.valid_set_pad / len(data_val)) self.val_dataset = VlpnDataset( data_val, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=True, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, generator=self.generator, size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, ) self.val_dataloader = DataLoader( self.val_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) else: self.val_dataloader = None class VlpnDataset(IterableDataset): def __init__( self, 
        items: list[VlpnDataItem],
        tokenizer: CLIPTokenizer,
        num_buckets: int = 1,
        bucket_step_size: int = 64,
        bucket_max_pixels: Optional[int] = None,
        progressive_buckets: bool = False,
        batch_size: int = 1,
        fill_batch: bool = False,
        num_class_images: int = 0,
        size: int = 768,
        dropout: float = 0,
        shuffle: bool = False,
        interpolation: str = "bicubic",
        color_jitter: bool = True,
        generator: Optional[torch.Generator] = None,
        npgenerator: Optional[np.random.Generator] = None,
    ):
        self.items = items
        self.batch_size = batch_size
        self.fill_batch = fill_batch
        self.tokenizer = tokenizer
        self.num_class_images = num_class_images
        self.size = size
        self.dropout = dropout
        self.shuffle = shuffle
        self.interpolation = interpolations[interpolation]
        self.color_jitter = color_jitter
        self.generator = generator
        self.npgenerator = npgenerator

        self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
            [item.instance_image_path for item in self.items],
            base_size=size,
            step_size=bucket_step_size,
            num_buckets=num_buckets,
            max_pixels=bucket_max_pixels,
            progressive_buckets=progressive_buckets,
        )

        self.bucket_item_range = torch.arange(len(self.bucket_items))

        # Number of batches: each bucket contributes ceil(items_in_bucket / batch_size) batches.
        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()

    def get_input_ids(self, text: str):
        return self.tokenizer(text, padding="do_not_pad").input_ids

    def __len__(self):
        return self.length_

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        if self.shuffle:
            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
            self.bucket_items = self.bucket_items[perm]
            self.bucket_assignments = self.bucket_assignments[perm]

        image_transforms = None

        mask = torch.ones_like(self.bucket_assignments, dtype=bool)
        bucket = -1
        batch = []
        batch_size = self.batch_size

        if worker_info is not None:
            # Restrict each dataloader worker to its own slice of the item mask.
            batch_size = math.ceil(batch_size / worker_info.num_workers)
            worker_batch = math.ceil(len(self) / worker_info.num_workers)
            start = worker_info.id * worker_batch
            end = start + worker_batch
            mask[:start] = False
            mask[end:] = False

        while mask.any() or len(batch) != 0:
            if len(batch) >= batch_size:
                yield batch
                batch = []
                continue

            bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
            bucket_items = self.bucket_items[bucket_mask]

            if len(bucket_items) == 0 and len(batch) != 0 and not self.fill_batch:
                # Current bucket is exhausted: emit the partial batch as-is.
                yield batch
                batch = []
                continue

            if len(bucket_items) == 0 and len(batch) == 0:
                # Move on to the next unconsumed bucket and rebuild the image
                # transforms for its aspect ratio.
                bucket = self.bucket_assignments[mask][0]
                ratio = self.buckets[bucket]
                width = int(self.size * ratio) if ratio > 1 else self.size
                height = int(self.size / ratio) if ratio < 1 else self.size

                image_transforms = [
                    transforms.Resize(self.size, interpolation=self.interpolation),
                    transforms.RandomCrop((height, width)),
                    transforms.RandomHorizontalFlip(),
                ]
                if self.color_jitter:
                    image_transforms += [
                        transforms.ColorJitter(0.2, 0.1),
                    ]
                image_transforms += [
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5]),
                ]
                image_transforms = transforms.Compose(image_transforms)
                continue

            if len(bucket_items) == 0:
                # fill_batch: top up the final batch of a bucket with random repeats
                # drawn from the same bucket.
                bucket_items = self.bucket_items[self.bucket_assignments == bucket]
                item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)]
            else:
                item_index = bucket_items[0]
                mask[self.bucket_item_range[bucket_mask][0]] = False

            item = self.items[item_index]

            example = {}
            example["prompt_ids"] = self.get_input_ids(item.full_prompt())
            example["nprompt_ids"] = self.get_input_ids(item.nprompt)
            example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator))
example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) example["instance_images"] = image_transforms(get_image(item.instance_image_path)) if self.num_class_images != 0: example["class_prompt_ids"] = self.get_input_ids(item.cprompt) example["class_images"] = image_transforms(get_image(item.class_image_path)) batch.append(example)