From 64c79cc3e7fad49131f90fbb0648b6d5587563e5 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 10 Dec 2022 08:43:34 +0100 Subject: Various updated; shuffle prompt content during training --- data/csv.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) (limited to 'data') diff --git a/data/csv.py b/data/csv.py index 67ac43b..23b5299 100644 --- a/data/csv.py +++ b/data/csv.py @@ -1,6 +1,7 @@ import math import torch import json +import numpy as np from pathlib import Path import pytorch_lightning as pl from PIL import Image @@ -15,6 +16,19 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt +def shuffle_prompt(prompt: str): + def handle_block(block: str): + words = block.split(", ") + np.random.shuffle(words) + return ", ".join(words) + + prompt = prompt.split(". ") + prompt = [handle_block(b) for b in prompt] + np.random.shuffle(prompt) + prompt = ". ".join(prompt) + return prompt + + class CSVDataItem(NamedTuple): instance_image_path: Path class_image_path: Path @@ -190,30 +204,27 @@ class CSVDataset(Dataset): item = self.data[i % self.num_instance_images] example = {} - example["prompts"] = item.prompt example["nprompts"] = item.nprompt - example["instance_images"] = self.get_image(item.instance_image_path) - example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier) - if self.num_class_images != 0: example["class_images"] = self.get_image(item.class_image_path) - example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier) return example def __getitem__(self, i): - example = {} unprocessed_example = self.get_example(i) - example["prompts"] = unprocessed_example["prompts"] + example = {} + + example["prompts"] = shuffle_prompt(unprocessed_example["prompts"]) example["nprompts"] = unprocessed_example["nprompts"] + example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) - example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] + example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) if self.num_class_images != 0: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) - example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] + example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) return example -- cgit v1.2.3-70-g09d2