import math
import torch
import json

from functools import partial
from pathlib import Path
from typing import NamedTuple, Optional, Union, Callable

from PIL import Image
from torch.utils.data import IterableDataset, DataLoader, random_split
from torchvision import transforms
from transformers import CLIPTokenizer

from data.keywords import prompt_to_keywords, keywords_to_prompt
from models.clip.util import unify_input_ids

cache = {}

interpolations = {
    "linear": transforms.InterpolationMode.NEAREST,
    "bilinear": transforms.InterpolationMode.BILINEAR,
    "bicubic": transforms.InterpolationMode.BICUBIC,
    "lanczos": transforms.InterpolationMode.LANCZOS,
}


def get_image(path):
    if path in cache:
        return cache[path]

    image = Image.open(path)
    if not image.mode == "RGB":
        image = image.convert("RGB")
    cache[path] = image
    return image


def prepare_prompt(prompt: Union[str, dict[str, str]]):
    return {"content": prompt} if isinstance(prompt, str) else prompt


def generate_buckets(
    items: Union[list[str], list[Path]],
    base_size: int,
    step_size: int = 64,
    max_pixels: Optional[int] = None,
    num_buckets: int = 4,
    progressive_buckets: bool = False,
    return_tensor: bool = True,
):
    if max_pixels is None:
        max_pixels = (base_size + step_size) ** 2

    max_pixels = max(max_pixels, base_size * base_size)

    bucket_items: list[int] = []
    bucket_assignments: list[int] = []
    buckets = [1.0]

    for i in range(1, num_buckets + 1):
        long_side = base_size + i * step_size
        short_side = min(
            base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size,
            base_size,
        )
        buckets.append(long_side / short_side)
        buckets.append(short_side / long_side)

    buckets = torch.tensor(buckets)
    bucket_indices = torch.arange(len(buckets))

    for i, item in enumerate(items):
        image = get_image(item)
        ratio = image.width / image.height

        if ratio >= 1:
            mask = torch.logical_and(buckets >= 1, buckets <= ratio)
        else:
            mask = torch.logical_and(buckets <= 1, buckets >= ratio)

        if not progressive_buckets:
            inf = torch.zeros_like(buckets)
            inf[~mask] = math.inf
            mask = (buckets + inf - ratio).abs().argmin()

        indices = bucket_indices[mask]

        if len(indices.shape) == 0:
            indices = indices.unsqueeze(0)

        bucket_items += [i] * len(indices)
        bucket_assignments += indices

    if return_tensor:
        bucket_items = torch.tensor(bucket_items)
        bucket_assignments = torch.tensor(bucket_assignments)
    else:
        buckets = buckets.tolist()

    return buckets, bucket_items, bucket_assignments


def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples):
    prompt_ids = [example["prompt_ids"] for example in examples]
    nprompt_ids = [example["nprompt_ids"] for example in examples]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)

    prompts = unify_input_ids(tokenizer, prompt_ids)
    nprompts = unify_input_ids(tokenizer, nprompt_ids)
    inputs = unify_input_ids(tokenizer, input_ids)

    batch = {
        "prompt_ids": prompts.input_ids,
        "nprompt_ids": nprompts.input_ids,
        "input_ids": inputs.input_ids,
        "pixel_values": pixel_values,
        "attention_mask": inputs.attention_mask,
    }
    return batch
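
# VlpnDataItem bundles everything known about a single training example:
#   instance_image_path  path of the instance (training) image
#   class_image_path     path of the matching prior-preservation image
#                        (filled in by VlpnDataModule.pad_items)
#   prompt               instance prompt as a list of keywords
#   cprompt              class prompt used for prior preservation
#   nprompt              negative prompt
#   collection           collection tags, used by keyword_filter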
class VlpnDataItem(NamedTuple):
    instance_image_path: Path
    class_image_path: Path
    prompt: list[str]
    cprompt: str
    nprompt: str
    collection: list[str]


def keyword_filter(
    placeholder_tokens: Optional[list[str]],
    collection: Optional[list[str]],
    exclude_collections: Optional[list[str]],
    item: VlpnDataItem,
):
    cond1 = placeholder_tokens is None or any(
        keyword in part
        for keyword in placeholder_tokens
        for part in item.prompt
    )
    cond2 = collection is None or any(c in item.collection for c in collection)
    cond3 = exclude_collections is None or not any(
        excluded in item.collection
        for excluded in exclude_collections
    )
    return cond1 and cond2 and cond3


class VlpnDataModule:
    def __init__(
        self,
        batch_size: int,
        data_file: str,
        tokenizer: CLIPTokenizer,
        class_subdir: str = "cls",
        num_class_images: int = 1,
        size: int = 768,
        num_buckets: int = 0,
        bucket_step_size: int = 64,
        bucket_max_pixels: Optional[int] = None,
        progressive_buckets: bool = False,
        dropout: float = 0,
        shuffle: bool = False,
        interpolation: str = "bicubic",
        template_key: str = "template",
        valid_set_size: Optional[int] = None,
        train_set_pad: Optional[int] = None,
        valid_set_pad: Optional[int] = None,
        seed: Optional[int] = None,
        filter: Optional[Callable[[VlpnDataItem], bool]] = None,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()

        self.data_file = Path(data_file)

        if not self.data_file.is_file():
            raise ValueError("data_file must be a file")

        self.data_root = self.data_file.parent
        self.class_root = self.data_root / class_subdir
        self.class_root.mkdir(parents=True, exist_ok=True)
        self.num_class_images = num_class_images

        self.tokenizer = tokenizer
        self.size = size
        self.num_buckets = num_buckets
        self.bucket_step_size = bucket_step_size
        self.bucket_max_pixels = bucket_max_pixels
        self.progressive_buckets = progressive_buckets
        self.dropout = dropout
        self.shuffle = shuffle
        self.template_key = template_key
        self.interpolation = interpolation
        self.valid_set_size = valid_set_size
        self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size
        self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size
        self.seed = seed
        self.filter = filter
        self.batch_size = batch_size
        self.dtype = dtype

    def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
        image = template["image"] if "image" in template else "{}"
        prompt = template["prompt"] if "prompt" in template else "{content}"
        cprompt = template["cprompt"] if "cprompt" in template else "{content}"
        nprompt = template["nprompt"] if "nprompt" in template else "{content}"

        return [
            VlpnDataItem(
                self.data_root / image.format(item["image"]),
                None,
                prompt_to_keywords(
                    prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
                    expansions,
                ),
                keywords_to_prompt(prompt_to_keywords(
                    cprompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
                    expansions,
                )),
                keywords_to_prompt(prompt_to_keywords(
                    nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
                    expansions,
                )),
                item["collection"].split(", ") if "collection" in item else [],
            )
            for item in data
        ]

    def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]:
        if self.filter is None:
            return items

        return [item for item in items if self.filter(item)]

    def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
        image_multiplier = max(num_class_images, 1)

        return [
            VlpnDataItem(
                item.instance_image_path,
                self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
                item.prompt,
                item.cprompt,
                item.nprompt,
                item.collection,
            )
            for item in items
            for i in range(image_multiplier)
        ]
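    # The data file read by setup() is expected to look roughly like the
    # sketch below. This layout is inferred from prepare_items()/setup();
    # the exact template fields and item keys depend on your metadata file.
    #
    #   {
    #     "template": {
    #       "image": "images/{}",
    #       "prompt": "a photo of {content}",
    #       "cprompt": "{content}",
    #       "nprompt": "{content}"
    #     },
    #     "expansions": {},
    #     "items": [
    #       {"image": "0001.jpg", "prompt": "a dog", "nprompt": "blurry", "collection": "pets, dogs"}
    #     ]
    #   }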
    def setup(self):
        with open(self.data_file, "rt") as f:
            metadata = json.load(f)

        template = metadata[self.template_key] if self.template_key in metadata else {}
        expansions = metadata["expansions"] if "expansions" in metadata else {}
        items = metadata["items"] if "items" in metadata else []

        items = self.prepare_items(template, expansions, items)
        items = self.filter_items(items)
        num_images = len(items)

        # Default to a 10% validation split when no explicit size is given.
        valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
        train_set_size = max(num_images - valid_set_size, 1)
        valid_set_size = num_images - train_set_size

        generator = torch.Generator(device="cpu")
        if self.seed is not None:
            generator = generator.manual_seed(self.seed)

        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)

        if valid_set_size == 0:
            # No validation split: reuse the first batch of training items for validation.
            data_train, data_val = items, items[:self.batch_size]
        else:
            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)

        data_train = self.pad_items(data_train, self.num_class_images)

        # Repeat items so the training set contains at least train_set_pad entries.
        if len(data_train) < self.train_set_pad:
            data_train *= math.ceil(self.train_set_pad / len(data_train))

        self.train_dataset = VlpnDataset(
            data_train,
            self.tokenizer,
            num_buckets=self.num_buckets,
            progressive_buckets=self.progressive_buckets,
            bucket_step_size=self.bucket_step_size,
            bucket_max_pixels=self.bucket_max_pixels,
            batch_size=self.batch_size,
            fill_batch=True,
            generator=generator,
            size=self.size,
            interpolation=self.interpolation,
            num_class_images=self.num_class_images,
            dropout=self.dropout,
            shuffle=self.shuffle,
        )

        self.train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=None,
            pin_memory=True,
            collate_fn=collate_fn_,
        )

        if len(data_val) != 0:
            data_val = self.pad_items(data_val)

            if len(data_val) < self.valid_set_pad:
                data_val *= math.ceil(self.valid_set_pad / len(data_val))

            self.val_dataset = VlpnDataset(
                data_val,
                self.tokenizer,
                num_buckets=self.num_buckets,
                progressive_buckets=True,
                bucket_step_size=self.bucket_step_size,
                bucket_max_pixels=self.bucket_max_pixels,
                batch_size=self.batch_size,
                generator=generator,
                size=self.size,
                interpolation=self.interpolation,
            )

            self.val_dataloader = DataLoader(
                self.val_dataset,
                batch_size=None,
                pin_memory=True,
                collate_fn=collate_fn_,
            )
        else:
            self.val_dataloader = None


class VlpnDataset(IterableDataset):
    def __init__(
        self,
        items: list[VlpnDataItem],
        tokenizer: CLIPTokenizer,
        num_buckets: int = 1,
        bucket_step_size: int = 64,
        bucket_max_pixels: Optional[int] = None,
        progressive_buckets: bool = False,
        batch_size: int = 1,
        fill_batch: bool = False,
        num_class_images: int = 0,
        size: int = 768,
        dropout: float = 0,
        shuffle: bool = False,
        interpolation: str = "bicubic",
        generator: Optional[torch.Generator] = None,
    ):
        self.items = items
        self.batch_size = batch_size
        self.fill_batch = fill_batch

        self.tokenizer = tokenizer
        self.num_class_images = num_class_images
        self.size = size
        self.dropout = dropout
        self.shuffle = shuffle
        self.interpolation = interpolations[interpolation]
        self.generator = generator

        self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
            [item.instance_image_path for item in self.items],
            base_size=size,
            step_size=bucket_step_size,
            num_buckets=num_buckets,
            max_pixels=bucket_max_pixels,
            progressive_buckets=progressive_buckets,
        )

        self.bucket_item_range = torch.arange(len(self.bucket_items))

        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()

    def get_input_ids(self, text: str):
        return self.tokenizer(text, padding="do_not_pad").input_ids

    def __len__(self):
        return self.length_
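    # __iter__ below yields lists of examples; the DataLoader is created with
    # batch_size=None, so each yielded list is one batch. Items are grouped by
    # their assigned aspect-ratio bucket: a bucket is picked, a resize/crop
    # transform matching its ratio is built, and items from that bucket are
    # collected until the batch is full or the bucket is exhausted. With
    # fill_batch=True, a short final batch is topped up by re-sampling random
    # items from the current bucket; with shuffle=True, the bucket assignments
    # are permuted first. When multiple DataLoader workers are active, the
    # per-worker batch size is reduced and each worker restricts the item mask
    # to its own slice.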
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        if self.shuffle:
            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
            self.bucket_items = self.bucket_items[perm]
            self.bucket_assignments = self.bucket_assignments[perm]

        image_transforms = None

        mask = torch.ones_like(self.bucket_assignments, dtype=torch.bool)
        bucket = -1
        batch = []
        batch_size = self.batch_size

        if worker_info is not None:
            batch_size = math.ceil(batch_size / worker_info.num_workers)
            worker_batch = math.ceil(len(self) / worker_info.num_workers)
            start = worker_info.id * worker_batch
            end = start + worker_batch
            mask[:start] = False
            mask[end:] = False

        while mask.any() or len(batch) != 0:
            if len(batch) >= batch_size:
                yield batch
                batch = []
                continue

            bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
            bucket_items = self.bucket_items[bucket_mask]

            if len(bucket_items) == 0 and len(batch) != 0 and not self.fill_batch:
                yield batch
                batch = []
                continue

            if len(bucket_items) == 0 and len(batch) == 0:
                bucket = self.bucket_assignments[mask][0]
                ratio = self.buckets[bucket]
                width = int(self.size * ratio) if ratio > 1 else self.size
                height = int(self.size / ratio) if ratio < 1 else self.size

                image_transforms = transforms.Compose(
                    [
                        transforms.Resize(self.size, interpolation=self.interpolation),
                        transforms.RandomCrop((height, width)),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize([0.5], [0.5]),
                    ]
                )
                continue

            if len(bucket_items) == 0:
                bucket_items = self.bucket_items[self.bucket_assignments == bucket]
                item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)]
            else:
                item_index = bucket_items[0]
                mask[self.bucket_item_range[bucket_mask][0]] = False

            item = self.items[item_index]

            example = {}
            example["prompt_ids"] = self.get_input_ids(keywords_to_prompt(item.prompt))
            example["nprompt_ids"] = self.get_input_ids(item.nprompt)
            example["instance_prompt_ids"] = self.get_input_ids(
                keywords_to_prompt(item.prompt, self.dropout, True)
            )
            example["instance_images"] = image_transforms(get_image(item.instance_image_path))

            if self.num_class_images != 0:
                example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
                example["class_images"] = image_transforms(get_image(item.class_image_path))

            batch.append(example)
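
# A minimal usage sketch, not part of the pipeline itself. The metadata path
# and tokenizer name below are placeholders; adjust them to your setup.
if __name__ == "__main__":
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    datamodule = VlpnDataModule(
        batch_size=4,
        data_file="data/template/metadata.json",  # hypothetical path
        tokenizer=tokenizer,
        num_class_images=0,
        size=512,
        num_buckets=4,
        seed=42,
        filter=partial(keyword_filter, None, None, None),
    )
    datamodule.setup()

    # Inspect the first training batch produced by the bucketed dataset.
    for batch in datamodule.train_dataloader:
        print(batch["input_ids"].shape, batch["pixel_values"].shape)
        break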