summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-11 12:59:13 +0100
committerVolpeon <git@volpeon.ink>2022-12-11 12:59:13 +0100
commit8f2b8e8d309470babd9b853fde8f0a081366deae (patch)
tree1374e791705e31fa77fefeb5001aad204cdf3224 /data
parentSupport attention_mask of text encoder (diff)
downloadtextual-inversion-diff-8f2b8e8d309470babd9b853fde8f0a081366deae.tar.gz
textual-inversion-diff-8f2b8e8d309470babd9b853fde8f0a081366deae.tar.bz2
textual-inversion-diff-8f2b8e8d309470babd9b853fde8f0a081366deae.zip
Training improvements such as tag dropout
Diffstat (limited to 'data')
-rw-r--r--data/csv.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index 23b5299..9125212 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -16,14 +16,17 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
16 return {"content": prompt} if isinstance(prompt, str) else prompt 16 return {"content": prompt} if isinstance(prompt, str) else prompt
17 17
18 18
19def shuffle_prompt(prompt: str): 19def shuffle_prompt(prompt: str, dropout: float = 0):
20 def handle_block(block: str): 20 def handle_block(block: str):
21 words = block.split(", ") 21 words = block.split(", ")
22 words = [w for w in words if w != ""]
23 if dropout != 0:
24 words = [w for w in words if np.random.random() > dropout]
22 np.random.shuffle(words) 25 np.random.shuffle(words)
23 return ", ".join(words) 26 return ", ".join(words)
24 27
25 prompt = prompt.split(". ") 28 prompt = prompt.split(". ")
26 prompt = [handle_block(b) for b in prompt] 29 prompt = [handle_block(b) for b in prompt if b != ""]
27 np.random.shuffle(prompt) 30 np.random.shuffle(prompt)
28 prompt = ". ".join(prompt) 31 prompt = ". ".join(prompt)
29 return prompt 32 return prompt
@@ -48,6 +51,7 @@ class CSVDataModule(pl.LightningDataModule):
48 num_class_images: int = 100, 51 num_class_images: int = 100,
49 size: int = 512, 52 size: int = 512,
50 repeats: int = 1, 53 repeats: int = 1,
54 dropout: float = 0,
51 interpolation: str = "bicubic", 55 interpolation: str = "bicubic",
52 center_crop: bool = False, 56 center_crop: bool = False,
53 valid_set_size: Optional[int] = None, 57 valid_set_size: Optional[int] = None,
@@ -72,6 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
72 self.class_identifier = class_identifier 76 self.class_identifier = class_identifier
73 self.size = size 77 self.size = size
74 self.repeats = repeats 78 self.repeats = repeats
79 self.dropout = dropout
75 self.center_crop = center_crop 80 self.center_crop = center_crop
76 self.interpolation = interpolation 81 self.interpolation = interpolation
77 self.valid_set_size = valid_set_size 82 self.valid_set_size = valid_set_size
@@ -123,7 +128,7 @@ class CSVDataModule(pl.LightningDataModule):
123 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, 128 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
124 num_class_images=self.num_class_images, 129 num_class_images=self.num_class_images,
125 size=self.size, interpolation=self.interpolation, 130 size=self.size, interpolation=self.interpolation,
126 center_crop=self.center_crop, repeats=self.repeats) 131 center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout)
127 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, 132 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
128 instance_identifier=self.instance_identifier, 133 instance_identifier=self.instance_identifier,
129 size=self.size, interpolation=self.interpolation, 134 size=self.size, interpolation=self.interpolation,
@@ -153,6 +158,7 @@ class CSVDataset(Dataset):
153 num_class_images: int = 0, 158 num_class_images: int = 0,
154 size: int = 512, 159 size: int = 512,
155 repeats: int = 1, 160 repeats: int = 1,
161 dropout: float = 0,
156 interpolation: str = "bicubic", 162 interpolation: str = "bicubic",
157 center_crop: bool = False, 163 center_crop: bool = False,
158 ): 164 ):
@@ -163,6 +169,7 @@ class CSVDataset(Dataset):
163 self.instance_identifier = instance_identifier 169 self.instance_identifier = instance_identifier
164 self.class_identifier = class_identifier 170 self.class_identifier = class_identifier
165 self.num_class_images = num_class_images 171 self.num_class_images = num_class_images
172 self.dropout = dropout
166 self.image_cache = {} 173 self.image_cache = {}
167 174
168 self.num_instance_images = len(self.data) 175 self.num_instance_images = len(self.data)