summaryrefslogtreecommitdiffstats
path: root/data/dreambooth
diff options
context:
space:
mode:
Diffstat (limited to 'data/dreambooth')
-rw-r--r--data/dreambooth/csv.py112
-rw-r--r--data/dreambooth/prompt.py4
2 files changed, 55 insertions, 61 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index c0b0067..4ebdc13 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -13,13 +13,11 @@ class CSVDataModule(pl.LightningDataModule):
13 batch_size, 13 batch_size,
14 data_file, 14 data_file,
15 tokenizer, 15 tokenizer,
16 instance_prompt, 16 instance_identifier,
17 class_data_root=None, 17 class_identifier=None,
18 class_prompt=None,
19 size=512, 18 size=512,
20 repeats=100, 19 repeats=100,
21 interpolation="bicubic", 20 interpolation="bicubic",
22 identifier="*",
23 center_crop=False, 21 center_crop=False,
24 valid_set_size=None, 22 valid_set_size=None,
25 generator=None, 23 generator=None,
@@ -32,13 +30,14 @@ class CSVDataModule(pl.LightningDataModule):
32 raise ValueError("data_file must be a file") 30 raise ValueError("data_file must be a file")
33 31
34 self.data_root = self.data_file.parent 32 self.data_root = self.data_file.parent
33 self.class_root = self.data_root.joinpath("db_cls")
34 self.class_root.mkdir(parents=True, exist_ok=True)
35
35 self.tokenizer = tokenizer 36 self.tokenizer = tokenizer
36 self.instance_prompt = instance_prompt 37 self.instance_identifier = instance_identifier
37 self.class_data_root = class_data_root 38 self.class_identifier = class_identifier
38 self.class_prompt = class_prompt
39 self.size = size 39 self.size = size
40 self.repeats = repeats 40 self.repeats = repeats
41 self.identifier = identifier
42 self.center_crop = center_crop 41 self.center_crop = center_crop
43 self.interpolation = interpolation 42 self.interpolation = interpolation
44 self.valid_set_size = valid_set_size 43 self.valid_set_size = valid_set_size
@@ -48,30 +47,36 @@ class CSVDataModule(pl.LightningDataModule):
48 47
49 def prepare_data(self): 48 def prepare_data(self):
50 metadata = pd.read_csv(self.data_file) 49 metadata = pd.read_csv(self.data_file)
51 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] 50 instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values]
51 class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values]
52 prompts = metadata['prompt'].values 52 prompts = metadata['prompt'].values
53 nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) 53 nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths)
54 skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) 54 skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths)
55 self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] 55 self.data = [(i, c, p, n)
56 for i, c, p, n, s
57 in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
58 if s != "x"]
56 59
57 def setup(self, stage=None): 60 def setup(self, stage=None):
58 valid_set_size = int(len(self.data_full) * 0.2) 61 valid_set_size = int(len(self.data) * 0.2)
59 if self.valid_set_size: 62 if self.valid_set_size:
60 valid_set_size = min(valid_set_size, self.valid_set_size) 63 valid_set_size = min(valid_set_size, self.valid_set_size)
61 train_set_size = len(self.data_full) - valid_set_size 64 valid_set_size = max(valid_set_size, 1)
62 65 train_set_size = len(self.data) - valid_set_size
63 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) 66
64 67 self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator)
65 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, 68
66 class_data_root=self.class_data_root, class_prompt=self.class_prompt, 69 train_dataset = CSVDataset(self.data_train, self.tokenizer,
67 size=self.size, interpolation=self.interpolation, identifier=self.identifier, 70 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
68 center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size) 71 size=self.size, interpolation=self.interpolation,
69 val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, 72 center_crop=self.center_crop, repeats=self.repeats)
70 size=self.size, interpolation=self.interpolation, identifier=self.identifier, 73 val_dataset = CSVDataset(self.data_val, self.tokenizer,
71 center_crop=self.center_crop, batch_size=self.batch_size) 74 instance_identifier=self.instance_identifier,
72 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 75 size=self.size, interpolation=self.interpolation,
76 center_crop=self.center_crop, repeats=self.repeats)
77 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True,
73 shuffle=True, pin_memory=True, collate_fn=self.collate_fn) 78 shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
74 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, 79 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True,
75 pin_memory=True, collate_fn=self.collate_fn) 80 pin_memory=True, collate_fn=self.collate_fn)
76 81
77 def train_dataloader(self): 82 def train_dataloader(self):
@@ -85,39 +90,23 @@ class CSVDataset(Dataset):
85 def __init__(self, 90 def __init__(self,
86 data, 91 data,
87 tokenizer, 92 tokenizer,
88 instance_prompt, 93 instance_identifier,
89 class_data_root=None, 94 class_identifier=None,
90 class_prompt=None,
91 size=512, 95 size=512,
92 repeats=1, 96 repeats=1,
93 interpolation="bicubic", 97 interpolation="bicubic",
94 identifier="*",
95 center_crop=False, 98 center_crop=False,
96 batch_size=1,
97 ): 99 ):
98 100
99 self.data = data 101 self.data = data
100 self.tokenizer = tokenizer 102 self.tokenizer = tokenizer
101 self.instance_prompt = instance_prompt 103 self.instance_identifier = instance_identifier
102 self.identifier = identifier 104 self.class_identifier = class_identifier
103 self.batch_size = batch_size
104 self.cache = {} 105 self.cache = {}
105 106
106 self.num_instance_images = len(self.data) 107 self.num_instance_images = len(self.data)
107 self._length = self.num_instance_images * repeats 108 self._length = self.num_instance_images * repeats
108 109
109 if class_data_root is not None:
110 self.class_data_root = Path(class_data_root)
111 self.class_data_root.mkdir(parents=True, exist_ok=True)
112
113 self.class_images = list(self.class_data_root.iterdir())
114 self.num_class_images = len(self.class_images)
115 self._length = max(self.num_class_images, self.num_instance_images)
116
117 self.class_prompt = class_prompt
118 else:
119 self.class_data_root = None
120
121 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, 110 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
122 "bilinear": transforms.InterpolationMode.BILINEAR, 111 "bilinear": transforms.InterpolationMode.BILINEAR,
123 "bicubic": transforms.InterpolationMode.BICUBIC, 112 "bicubic": transforms.InterpolationMode.BICUBIC,
@@ -134,46 +123,49 @@ class CSVDataset(Dataset):
134 ) 123 )
135 124
136 def __len__(self): 125 def __len__(self):
137 return math.ceil(self._length / self.batch_size) * self.batch_size 126 return self._length
138 127
139 def get_example(self, i): 128 def get_example(self, i):
140 image_path, prompt, nprompt = self.data[i % self.num_instance_images] 129 instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
141 130
142 if image_path in self.cache: 131 if instance_image_path in self.cache:
143 return self.cache[image_path] 132 return self.cache[instance_image_path]
144 133
145 example = {} 134 example = {}
146 135
147 instance_image = Image.open(image_path) 136 example["prompts"] = prompt
137 example["nprompts"] = nprompt
138
139 instance_image = Image.open(instance_image_path)
148 if not instance_image.mode == "RGB": 140 if not instance_image.mode == "RGB":
149 instance_image = instance_image.convert("RGB") 141 instance_image = instance_image.convert("RGB")
150 142
151 prompt = prompt.format(self.identifier) 143 instance_prompt = prompt.format(self.instance_identifier)
152 144
153 example["prompts"] = prompt
154 example["nprompts"] = nprompt
155 example["instance_images"] = instance_image 145 example["instance_images"] = instance_image
156 example["instance_prompt_ids"] = self.tokenizer( 146 example["instance_prompt_ids"] = self.tokenizer(
157 self.instance_prompt, 147 instance_prompt,
158 padding="do_not_pad", 148 padding="do_not_pad",
159 truncation=True, 149 truncation=True,
160 max_length=self.tokenizer.model_max_length, 150 max_length=self.tokenizer.model_max_length,
161 ).input_ids 151 ).input_ids
162 152
163 if self.class_data_root: 153 if self.class_identifier:
164 class_image = Image.open(self.class_images[i % self.num_class_images]) 154 class_image = Image.open(class_image_path)
165 if not class_image.mode == "RGB": 155 if not class_image.mode == "RGB":
166 class_image = class_image.convert("RGB") 156 class_image = class_image.convert("RGB")
167 157
158 class_prompt = prompt.format(self.class_identifier)
159
168 example["class_images"] = class_image 160 example["class_images"] = class_image
169 example["class_prompt_ids"] = self.tokenizer( 161 example["class_prompt_ids"] = self.tokenizer(
170 self.class_prompt, 162 class_prompt,
171 padding="do_not_pad", 163 padding="do_not_pad",
172 truncation=True, 164 truncation=True,
173 max_length=self.tokenizer.model_max_length, 165 max_length=self.tokenizer.model_max_length,
174 ).input_ids 166 ).input_ids
175 167
176 self.cache[image_path] = example 168 self.cache[instance_image_path] = example
177 return example 169 return example
178 170
179 def __getitem__(self, i): 171 def __getitem__(self, i):
@@ -185,7 +177,7 @@ class CSVDataset(Dataset):
185 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 177 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
186 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] 178 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
187 179
188 if self.class_data_root: 180 if self.class_identifier:
189 example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) 181 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
190 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] 182 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
191 183
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
index 34f510d..b3a83ce 100644
--- a/data/dreambooth/prompt.py
+++ b/data/dreambooth/prompt.py
@@ -2,8 +2,9 @@ from torch.utils.data import Dataset
2 2
3 3
4class PromptDataset(Dataset): 4class PromptDataset(Dataset):
5 def __init__(self, prompt, num_samples): 5 def __init__(self, prompt, nprompt, num_samples):
6 self.prompt = prompt 6 self.prompt = prompt
7 self.nprompt = nprompt
7 self.num_samples = num_samples 8 self.num_samples = num_samples
8 9
9 def __len__(self): 10 def __len__(self):
@@ -12,5 +13,6 @@ class PromptDataset(Dataset):
12 def __getitem__(self, index): 13 def __getitem__(self, index):
13 example = {} 14 example = {}
14 example["prompt"] = self.prompt 15 example["prompt"] = self.prompt
16 example["nprompt"] = self.nprompt
15 example["index"] = index 17 example["index"] = index
16 return example 18 return example