summary refs log tree commit diff stats
path: root/data
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2022-10-07 08:34:07 +0200
committer Volpeon <git@volpeon.ink> 2022-10-07 08:34:07 +0200
commit 14beba63391e1ddc9a145bb638d9306086ad1a5c (patch)
tree 4e7d5126359c4ab6ab6dff3c2af537d659e276e8 /data
parent Update (diff)
download textual-inversion-diff-14beba63391e1ddc9a145bb638d9306086ad1a5c.tar.gz
textual-inversion-diff-14beba63391e1ddc9a145bb638d9306086ad1a5c.tar.bz2
textual-inversion-diff-14beba63391e1ddc9a145bb638d9306086ad1a5c.zip
Training: Create multiple class images per training image
Diffstat (limited to 'data')
-rw-r--r-- data/csv.py 54
1 files changed, 40 insertions, 14 deletions
diff --git a/data/csv.py b/data/csv.py
index abd329d..dcaf7d3 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,5 +1,3 @@
1import math
2import os
3import pandas as pd 1import pandas as pd
4from pathlib import Path 2from pathlib import Path
5import pytorch_lightning as pl 3import pytorch_lightning as pl
@@ -16,6 +14,7 @@ class CSVDataModule(pl.LightningDataModule):
16 instance_identifier, 14 instance_identifier,
17 class_identifier=None, 15 class_identifier=None,
18 class_subdir="db_cls", 16 class_subdir="db_cls",
17 num_class_images=2,
19 size=512, 18 size=512,
20 repeats=100, 19 repeats=100,
21 interpolation="bicubic", 20 interpolation="bicubic",
@@ -33,6 +32,7 @@ class CSVDataModule(pl.LightningDataModule):
33 self.data_root = self.data_file.parent 32 self.data_root = self.data_file.parent
34 self.class_root = self.data_root.joinpath(class_subdir) 33 self.class_root = self.data_root.joinpath(class_subdir)
35 self.class_root.mkdir(parents=True, exist_ok=True) 34 self.class_root.mkdir(parents=True, exist_ok=True)
35 self.num_class_images = num_class_images
36 36
37 self.tokenizer = tokenizer 37 self.tokenizer = tokenizer
38 self.instance_identifier = instance_identifier 38 self.instance_identifier = instance_identifier
@@ -48,15 +48,37 @@ class CSVDataModule(pl.LightningDataModule):
48 48
49 def prepare_data(self): 49 def prepare_data(self):
50 metadata = pd.read_csv(self.data_file) 50 metadata = pd.read_csv(self.data_file)
51 instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] 51 instance_image_paths = [
52 class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] 52 self.data_root.joinpath(f)
53 prompts = metadata['prompt'].values 53 for f in metadata['image'].values
54 nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) 54 for i in range(self.num_class_images)
55 skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) 55 ]
56 self.data = [(i, c, p, n) 56 class_image_paths = [
57 for i, c, p, n, s 57 self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}")
58 in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) 58 for f in metadata['image'].values
59 if s != "x"] 59 for i in range(self.num_class_images)
60 ]
61 prompts = [
62 prompt
63 for prompt in metadata['prompt'].values
64 for i in range(self.num_class_images)
65 ]
66 nprompts = [
67 nprompt
68 for nprompt in metadata['nprompt'].values
69 for i in range(self.num_class_images)
70 ] if 'nprompt' in metadata else [""] * len(instance_image_paths)
71 skips = [
72 skip
73 for skip in metadata['skip'].values
74 for i in range(self.num_class_images)
75 ] if 'skip' in metadata else [""] * len(instance_image_paths)
76 self.data = [
77 (i, c, p, n)
78 for i, c, p, n, s
79 in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
80 if s != "x"
81 ]
60 82
61 def setup(self, stage=None): 83 def setup(self, stage=None):
62 valid_set_size = int(len(self.data) * 0.2) 84 valid_set_size = int(len(self.data) * 0.2)
@@ -69,6 +91,7 @@ class CSVDataModule(pl.LightningDataModule):
69 91
70 train_dataset = CSVDataset(self.data_train, self.tokenizer, 92 train_dataset = CSVDataset(self.data_train, self.tokenizer,
71 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, 93 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
94 num_class_images=self.num_class_images,
72 size=self.size, interpolation=self.interpolation, 95 size=self.size, interpolation=self.interpolation,
73 center_crop=self.center_crop, repeats=self.repeats) 96 center_crop=self.center_crop, repeats=self.repeats)
74 val_dataset = CSVDataset(self.data_val, self.tokenizer, 97 val_dataset = CSVDataset(self.data_val, self.tokenizer,
@@ -93,6 +116,7 @@ class CSVDataset(Dataset):
93 tokenizer, 116 tokenizer,
94 instance_identifier, 117 instance_identifier,
95 class_identifier=None, 118 class_identifier=None,
119 num_class_images=2,
96 size=512, 120 size=512,
97 repeats=1, 121 repeats=1,
98 interpolation="bicubic", 122 interpolation="bicubic",
@@ -103,6 +127,7 @@ class CSVDataset(Dataset):
103 self.tokenizer = tokenizer 127 self.tokenizer = tokenizer
104 self.instance_identifier = instance_identifier 128 self.instance_identifier = instance_identifier
105 self.class_identifier = class_identifier 129 self.class_identifier = class_identifier
130 self.num_class_images = num_class_images
106 self.cache = {} 131 self.cache = {}
107 132
108 self.num_instance_images = len(self.data) 133 self.num_instance_images = len(self.data)
@@ -128,9 +153,10 @@ class CSVDataset(Dataset):
128 153
129 def get_example(self, i): 154 def get_example(self, i):
130 instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] 155 instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
156 cache_key = f"{instance_image_path}_{class_image_path}"
131 157
132 if instance_image_path in self.cache: 158 if cache_key in self.cache:
133 return self.cache[instance_image_path] 159 return self.cache[cache_key]
134 160
135 example = {} 161 example = {}
136 162
@@ -149,7 +175,7 @@ class CSVDataset(Dataset):
149 max_length=self.tokenizer.model_max_length, 175 max_length=self.tokenizer.model_max_length,
150 ).input_ids 176 ).input_ids
151 177
152 if self.class_identifier is not None: 178 if self.num_class_images != 0:
153 class_image = Image.open(class_image_path) 179 class_image = Image.open(class_image_path)
154 if not class_image.mode == "RGB": 180 if not class_image.mode == "RGB":
155 class_image = class_image.convert("RGB") 181 class_image = class_image.convert("RGB")