From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 13:28:49 +0200
Subject: Update

---
 data/csv.py | 183 +++++++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 125 insertions(+), 58 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 14380e8..d726033 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -49,7 +49,7 @@ def generate_buckets(
     max_pixels: Optional[int] = None,
     num_buckets: int = 4,
     progressive_buckets: bool = False,
-    return_tensor: bool = True
+    return_tensor: bool = True,
 ):
     if max_pixels is None:
         max_pixels = (base_size + step_size) ** 2
@@ -62,7 +62,11 @@ def generate_buckets(
 
     for i in range(1, num_buckets + 1):
         long_side = base_size + i * step_size
-        short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size)
+        short_side = min(
+            base_size
+            - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size,
+            base_size,
+        )
         buckets.append(long_side / short_side)
         buckets.append(short_side / long_side)
 
@@ -106,7 +110,7 @@ def collate_fn(
     max_token_id_length: Optional[int],
     with_guidance: bool,
     with_prior_preservation: bool,
-    examples
+    examples,
 ):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
@@ -125,7 +129,9 @@ def collate_fn(
     prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
     nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
     inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
-    negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
+    negative_inputs = unify_input_ids(
+        tokenizer, negative_input_ids, max_token_id_length
+    )
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -149,35 +155,39 @@ class VlpnDataItem(NamedTuple):
     nprompt: str
     collection: list[str]
 
-    def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None):
-        return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator)
+    def full_prompt(
+        self,
+        dropout: float = 0,
+        shuffle: bool = False,
+        npgenerator: Optional[np.random.Generator] = None,
+    ):
+        return keywords_to_str(
+            self.keywords, [self.prompt], dropout, shuffle, npgenerator
+        )
 
 
 def keyword_filter(
     placeholder_tokens: Optional[list[str]],
     collections: Optional[list[str]],
     exclude_collections: Optional[list[str]],
-    item: VlpnDataItem
+    item: VlpnDataItem,
 ):
     full_prompt = item.full_prompt()
     cond1 = placeholder_tokens is None or any(
-        token in full_prompt
-        for token in placeholder_tokens
+        token in full_prompt for token in placeholder_tokens
     )
 
     cond2 = collections is None or any(
-        collection in item.collection
-        for collection in collections
+        collection in item.collection for collection in collections
     )
 
     cond3 = exclude_collections is None or not any(
-        collection in item.collection
-        for collection in exclude_collections
+        collection in item.collection for collection in exclude_collections
     )
     return cond1 and cond2 and cond3
 
 
-class VlpnDataModule():
+class VlpnDataModule:
     def __init__(
         self,
         batch_size: int,
@@ -222,7 +232,7 @@ class VlpnDataModule():
 
         self.constant_prompt_length = constant_prompt_length
         self.max_token_id_length = None
-        
+
         self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
@@ -259,23 +269,29 @@ class VlpnDataModule():
             nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "")
             collection = item["collection"].split(", ") if "collection" in item else []
 
-            saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions)
+            saturated_keywords = str_to_keywords(
+                tpl_keywords.format(**keywords), expansions
+            )
 
-            inverted_tokens = keywords_to_str([
-                f"inv_{token}"
-                for token in self.placeholder_tokens
-                if token in saturated_keywords
-            ])
+            inverted_tokens = keywords_to_str(
+                [
+                    f"inv_{token}"
+                    for token in self.placeholder_tokens
+                    if token in saturated_keywords
+                ]
+            )
 
-            items.append(VlpnDataItem(
-                self.data_root / image,
-                None,
-                saturated_keywords,
-                tpl_prompt.format(**prompt),
-                tpl_cprompt.format(**prompt),
-                tpl_nprompt.format(_inv=inverted_tokens, **nprompt),
-                collection
-            ))
+            items.append(
+                VlpnDataItem(
+                    self.data_root / image,
+                    None,
+                    saturated_keywords,
+                    tpl_prompt.format(**prompt),
+                    tpl_cprompt.format(**prompt),
+                    tpl_nprompt.format(_inv=inverted_tokens, **nprompt),
+                    collection,
+                )
+            )
 
         return items
 
@@ -285,13 +301,16 @@ class VlpnDataModule():
 
         return [item for item in items if self.filter(item)]
 
-    def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
+    def pad_items(
+        self, items: list[VlpnDataItem], num_class_images: int = 1
+    ) -> list[VlpnDataItem]:
        image_multiplier = max(num_class_images, 1)
 
        return [
            VlpnDataItem(
                item.instance_image_path,
-                self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
+                self.class_root
+                / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
                item.keywords,
                item.prompt,
                item.cprompt,
@@ -303,7 +322,7 @@ class VlpnDataModule():
         ]
 
     def setup(self):
-        with open(self.data_file, 'rt') as f:
+        with open(self.data_file, "rt") as f:
             metadata = json.load(f)
         template = metadata[self.template_key] if self.template_key in metadata else {}
         expansions = metadata["expansions"] if "expansions" in metadata else {}
@@ -312,25 +331,41 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
         self.npgenerator.shuffle(items)
-        
+
         if self.constant_prompt_length:
             all_input_ids = unify_input_ids(
                 self.tokenizer,
-                [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
+                [
+                    self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids
+                    for item in items
+                ],
             ).input_ids
             self.max_token_id_length = all_input_ids.shape[1]
 
         num_images = len(items)
-        valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
+        valid_set_size = (
+            min(self.valid_set_size, num_images)
+            if self.valid_set_size is not None
+            else num_images // 10
+        )
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
+        collate_fn_ = partial(
+            collate_fn,
+            self.dtype,
+            self.tokenizer,
+            self.max_token_id_length,
+            self.with_guidance,
+            self.num_class_images != 0,
+        )
 
         if valid_set_size == 0:
             data_train, data_val = items, items
         else:
-            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator)
+            data_train, data_val = random_split(
+                items, [train_set_size, valid_set_size], generator=self.generator
+            )
 
         data_train = self.pad_items(data_train, self.num_class_images)
 
@@ -338,17 +373,25 @@ class VlpnDataModule():
             data_train *= math.ceil(self.train_set_pad / len(data_train))
 
         self.train_dataset = VlpnDataset(
-            data_train, self.tokenizer,
-            num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
-            bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-            batch_size=self.batch_size, fill_batch=True, generator=self.generator,
-            size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
-            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
+            data_train,
+            self.tokenizer,
+            num_buckets=self.num_buckets,
+            progressive_buckets=self.progressive_buckets,
+            bucket_step_size=self.bucket_step_size,
+            bucket_max_pixels=self.bucket_max_pixels,
+            batch_size=self.batch_size,
+            fill_batch=True,
+            generator=self.generator,
+            size=self.size,
+            interpolation=self.interpolation,
+            color_jitter=self.color_jitter,
+            num_class_images=self.num_class_images,
+            dropout=self.dropout,
+            shuffle=self.shuffle,
         )
 
         self.train_dataloader = DataLoader(
-            self.train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_
+            self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
         if len(data_val) != 0:
@@ -358,16 +401,24 @@ class VlpnDataModule():
                 data_val *= math.ceil(self.valid_set_pad / len(data_val))
 
             self.val_dataset = VlpnDataset(
-                data_val, self.tokenizer,
-                num_buckets=self.num_buckets, progressive_buckets=True,
-                bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-                batch_size=self.batch_size, generator=self.generator,
-                size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
+                data_val,
+                self.tokenizer,
+                num_buckets=self.num_buckets,
+                progressive_buckets=True,
+                bucket_step_size=self.bucket_step_size,
+                bucket_max_pixels=self.bucket_max_pixels,
+                batch_size=self.batch_size,
+                generator=self.generator,
+                size=self.size,
+                interpolation=self.interpolation,
+                color_jitter=self.color_jitter,
            )
 
             self.val_dataloader = DataLoader(
                 self.val_dataset,
-                batch_size=None, pin_memory=True, collate_fn=collate_fn_
+                batch_size=None,
+                pin_memory=True,
+                collate_fn=collate_fn_,
            )
         else:
             self.val_dataloader = None
@@ -418,7 +469,13 @@ class VlpnDataset(IterableDataset):
 
         self.bucket_item_range = torch.arange(len(self.bucket_items))
 
-        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
+        self.length_ = (
+            (self.bucket_assignments.bincount() / self.batch_size)
+            .ceil()
+            .long()
+            .sum()
+            .item()
+        )
 
     def get_input_ids(self, text: str):
         return self.tokenizer(text, padding="do_not_pad").input_ids
@@ -430,7 +487,9 @@ class VlpnDataset(IterableDataset):
         worker_info = torch.utils.data.get_worker_info()
 
         if self.shuffle:
-            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
+            perm = torch.randperm(
+                len(self.bucket_assignments), generator=self.generator
+            )
             self.bucket_items = self.bucket_items[perm]
             self.bucket_assignments = self.bucket_assignments[perm]
 
@@ -488,7 +547,9 @@ class VlpnDataset(IterableDataset):
                 if len(bucket_items) == 0:
                     bucket_items = self.bucket_items[self.bucket_assignments == bucket]
 
-                item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)]
+                item_index = bucket_items[
+                    torch.randint(len(bucket_items), (1,), generator=self.generator)
+                ]
             else:
                 item_index = bucket_items[0]
                 mask[self.bucket_item_range[bucket_mask][0]] = False
@@ -500,12 +561,18 @@ class VlpnDataset(IterableDataset):
 
             example["prompt_ids"] = self.get_input_ids(item.full_prompt())
             example["nprompt_ids"] = self.get_input_ids(item.nprompt)
-            example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator))
+            example["instance_prompt_ids"] = self.get_input_ids(
+                item.full_prompt(self.dropout, True, self.npgenerator)
+            )
             example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
-            example["instance_images"] = image_transforms(get_image(item.instance_image_path))
+            example["instance_images"] = image_transforms(
+                get_image(item.instance_image_path)
+            )
 
             if self.num_class_images != 0:
                 example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
-                example["class_images"] = image_transforms(get_image(item.class_image_path))
+                example["class_images"] = image_transforms(
+                    get_image(item.class_image_path)
+                )
 
             batch.append(example)
-- 
cgit v1.2.3-54-g00ecf