From 8f2b8e8d309470babd9b853fde8f0a081366deae Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Sun, 11 Dec 2022 12:59:13 +0100
Subject: Training improvements such as tag dropout

---
 data/csv.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

(limited to 'data')

diff --git a/data/csv.py b/data/csv.py
index 23b5299..9125212 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -16,14 +16,25 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-def shuffle_prompt(prompt: str):
+def shuffle_prompt(prompt: str, dropout: float = 0):
+    """Shuffle the tags of a prompt, optionally dropping some at random.
+
+    The prompt is split into ". "-separated blocks of ", "-separated tags.
+    Empty tags and blocks are discarded, each tag is dropped with
+    probability `dropout`, and tags and blocks are shuffled.
+    """
     def handle_block(block: str):
         words = block.split(", ")
+        words = [w for w in words if w != ""]
+        if dropout != 0:
+            words = [w for w in words if np.random.random() > dropout]
         np.random.shuffle(words)
         return ", ".join(words)
 
     prompt = prompt.split(". ")
-    prompt = [handle_block(b) for b in prompt]
+    prompt = [handle_block(b) for b in prompt if b != ""]
+    # Dropout can empty a whole block; drop it so no stray ". " is joined in.
+    prompt = [b for b in prompt if b != ""]
     np.random.shuffle(prompt)
     prompt = ". ".join(prompt)
     return prompt
@@ -48,6 +51,7 @@ class CSVDataModule(pl.LightningDataModule):
             num_class_images: int = 100,
             size: int = 512,
             repeats: int = 1,
+            dropout: float = 0,
             interpolation: str = "bicubic",
             center_crop: bool = False,
             valid_set_size: Optional[int] = None,
@@ -72,6 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.class_identifier = class_identifier
         self.size = size
         self.repeats = repeats
+        self.dropout = dropout
         self.center_crop = center_crop
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -123,7 +128,7 @@ class CSVDataModule(pl.LightningDataModule):
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                    num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
-                                   center_crop=self.center_crop, repeats=self.repeats)
+                                   center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout)
         val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
                                  instance_identifier=self.instance_identifier,
                                  size=self.size, interpolation=self.interpolation,
@@ -153,6 +158,7 @@ class CSVDataset(Dataset):
         num_class_images: int = 0,
         size: int = 512,
         repeats: int = 1,
+        dropout: float = 0,
         interpolation: str = "bicubic",
         center_crop: bool = False,
     ):
@@ -163,6 +169,7 @@ class CSVDataset(Dataset):
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
+        self.dropout = dropout
         self.image_cache = {}
 
         self.num_instance_images = len(self.data)
-- 
cgit v1.2.3-70-g09d2