From 8f2b8e8d309470babd9b853fde8f0a081366deae Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 11 Dec 2022 12:59:13 +0100 Subject: Training improvements such as tag dropout --- data/csv.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'data/csv.py') diff --git a/data/csv.py b/data/csv.py index 23b5299..9125212 100644 --- a/data/csv.py +++ b/data/csv.py @@ -16,14 +16,17 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt -def shuffle_prompt(prompt: str): +def shuffle_prompt(prompt: str, dropout: float = 0): def handle_block(block: str): words = block.split(", ") + words = [w for w in words if w != ""] + if dropout != 0: + words = [w for w in words if np.random.random() > dropout] np.random.shuffle(words) return ", ".join(words) prompt = prompt.split(". ") - prompt = [handle_block(b) for b in prompt] + prompt = [handle_block(b) for b in prompt if b != ""] np.random.shuffle(prompt) prompt = ". ".join(prompt) return prompt @@ -48,6 +51,7 @@ class CSVDataModule(pl.LightningDataModule): num_class_images: int = 100, size: int = 512, repeats: int = 1, + dropout: float = 0, interpolation: str = "bicubic", center_crop: bool = False, valid_set_size: Optional[int] = None, @@ -72,6 +76,7 @@ class CSVDataModule(pl.LightningDataModule): self.class_identifier = class_identifier self.size = size self.repeats = repeats + self.dropout = dropout self.center_crop = center_crop self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -123,7 +128,7 @@ class CSVDataModule(pl.LightningDataModule): instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, num_class_images=self.num_class_images, size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop, repeats=self.repeats) + center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, instance_identifier=self.instance_identifier, size=self.size, interpolation=self.interpolation, @@ -153,6 +158,7 @@ class CSVDataset(Dataset): num_class_images: int = 0, size: int = 512, repeats: int = 1, + dropout: float = 0, interpolation: str = "bicubic", center_crop: bool = False, ): @@ -163,6 +169,7 @@ class CSVDataset(Dataset): self.instance_identifier = instance_identifier self.class_identifier = class_identifier self.num_class_images = num_class_images + self.dropout = dropout self.image_cache = {} self.num_instance_images = len(self.data) -- cgit v1.2.3-54-g00ecf