diff options
Diffstat (limited to 'data/dreambooth')
-rw-r--r-- | data/dreambooth/csv.py | 112 | ||||
-rw-r--r-- | data/dreambooth/prompt.py | 4 |
2 files changed, 55 insertions, 61 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index c0b0067..4ebdc13 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -13,13 +13,11 @@ class CSVDataModule(pl.LightningDataModule): | |||
13 | batch_size, | 13 | batch_size, |
14 | data_file, | 14 | data_file, |
15 | tokenizer, | 15 | tokenizer, |
16 | instance_prompt, | 16 | instance_identifier, |
17 | class_data_root=None, | 17 | class_identifier=None, |
18 | class_prompt=None, | ||
19 | size=512, | 18 | size=512, |
20 | repeats=100, | 19 | repeats=100, |
21 | interpolation="bicubic", | 20 | interpolation="bicubic", |
22 | identifier="*", | ||
23 | center_crop=False, | 21 | center_crop=False, |
24 | valid_set_size=None, | 22 | valid_set_size=None, |
25 | generator=None, | 23 | generator=None, |
@@ -32,13 +30,14 @@ class CSVDataModule(pl.LightningDataModule): | |||
32 | raise ValueError("data_file must be a file") | 30 | raise ValueError("data_file must be a file") |
33 | 31 | ||
34 | self.data_root = self.data_file.parent | 32 | self.data_root = self.data_file.parent |
33 | self.class_root = self.data_root.joinpath("db_cls") | ||
34 | self.class_root.mkdir(parents=True, exist_ok=True) | ||
35 | |||
35 | self.tokenizer = tokenizer | 36 | self.tokenizer = tokenizer |
36 | self.instance_prompt = instance_prompt | 37 | self.instance_identifier = instance_identifier |
37 | self.class_data_root = class_data_root | 38 | self.class_identifier = class_identifier |
38 | self.class_prompt = class_prompt | ||
39 | self.size = size | 39 | self.size = size |
40 | self.repeats = repeats | 40 | self.repeats = repeats |
41 | self.identifier = identifier | ||
42 | self.center_crop = center_crop | 41 | self.center_crop = center_crop |
43 | self.interpolation = interpolation | 42 | self.interpolation = interpolation |
44 | self.valid_set_size = valid_set_size | 43 | self.valid_set_size = valid_set_size |
@@ -48,30 +47,36 @@ class CSVDataModule(pl.LightningDataModule): | |||
48 | 47 | ||
49 | def prepare_data(self): | 48 | def prepare_data(self): |
50 | metadata = pd.read_csv(self.data_file) | 49 | metadata = pd.read_csv(self.data_file) |
51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 50 | instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] |
51 | class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] | ||
52 | prompts = metadata['prompt'].values | 52 | prompts = metadata['prompt'].values |
53 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) | 53 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) |
54 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) | 54 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) |
55 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | 55 | self.data = [(i, c, p, n) |
56 | for i, c, p, n, s | ||
57 | in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) | ||
58 | if s != "x"] | ||
56 | 59 | ||
57 | def setup(self, stage=None): | 60 | def setup(self, stage=None): |
58 | valid_set_size = int(len(self.data_full) * 0.2) | 61 | valid_set_size = int(len(self.data) * 0.2) |
59 | if self.valid_set_size: | 62 | if self.valid_set_size: |
60 | valid_set_size = min(valid_set_size, self.valid_set_size) | 63 | valid_set_size = min(valid_set_size, self.valid_set_size) |
61 | train_set_size = len(self.data_full) - valid_set_size | 64 | valid_set_size = max(valid_set_size, 1) |
62 | 65 | train_set_size = len(self.data) - valid_set_size | |
63 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) | 66 | |
64 | 67 | self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) | |
65 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 68 | |
66 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, | 69 | train_dataset = CSVDataset(self.data_train, self.tokenizer, |
67 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, | 70 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
68 | center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size) | 71 | size=self.size, interpolation=self.interpolation, |
69 | val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, | 72 | center_crop=self.center_crop, repeats=self.repeats) |
70 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, | 73 | val_dataset = CSVDataset(self.data_val, self.tokenizer, |
71 | center_crop=self.center_crop, batch_size=self.batch_size) | 74 | instance_identifier=self.instance_identifier, |
72 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 75 | size=self.size, interpolation=self.interpolation, |
76 | center_crop=self.center_crop, repeats=self.repeats) | ||
77 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, | ||
73 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 78 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) |
74 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 79 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, |
75 | pin_memory=True, collate_fn=self.collate_fn) | 80 | pin_memory=True, collate_fn=self.collate_fn) |
76 | 81 | ||
77 | def train_dataloader(self): | 82 | def train_dataloader(self): |
@@ -85,39 +90,23 @@ class CSVDataset(Dataset): | |||
85 | def __init__(self, | 90 | def __init__(self, |
86 | data, | 91 | data, |
87 | tokenizer, | 92 | tokenizer, |
88 | instance_prompt, | 93 | instance_identifier, |
89 | class_data_root=None, | 94 | class_identifier=None, |
90 | class_prompt=None, | ||
91 | size=512, | 95 | size=512, |
92 | repeats=1, | 96 | repeats=1, |
93 | interpolation="bicubic", | 97 | interpolation="bicubic", |
94 | identifier="*", | ||
95 | center_crop=False, | 98 | center_crop=False, |
96 | batch_size=1, | ||
97 | ): | 99 | ): |
98 | 100 | ||
99 | self.data = data | 101 | self.data = data |
100 | self.tokenizer = tokenizer | 102 | self.tokenizer = tokenizer |
101 | self.instance_prompt = instance_prompt | 103 | self.instance_identifier = instance_identifier |
102 | self.identifier = identifier | 104 | self.class_identifier = class_identifier |
103 | self.batch_size = batch_size | ||
104 | self.cache = {} | 105 | self.cache = {} |
105 | 106 | ||
106 | self.num_instance_images = len(self.data) | 107 | self.num_instance_images = len(self.data) |
107 | self._length = self.num_instance_images * repeats | 108 | self._length = self.num_instance_images * repeats |
108 | 109 | ||
109 | if class_data_root is not None: | ||
110 | self.class_data_root = Path(class_data_root) | ||
111 | self.class_data_root.mkdir(parents=True, exist_ok=True) | ||
112 | |||
113 | self.class_images = list(self.class_data_root.iterdir()) | ||
114 | self.num_class_images = len(self.class_images) | ||
115 | self._length = max(self.num_class_images, self.num_instance_images) | ||
116 | |||
117 | self.class_prompt = class_prompt | ||
118 | else: | ||
119 | self.class_data_root = None | ||
120 | |||
121 | self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, | 110 | self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, |
122 | "bilinear": transforms.InterpolationMode.BILINEAR, | 111 | "bilinear": transforms.InterpolationMode.BILINEAR, |
123 | "bicubic": transforms.InterpolationMode.BICUBIC, | 112 | "bicubic": transforms.InterpolationMode.BICUBIC, |
@@ -134,46 +123,49 @@ class CSVDataset(Dataset): | |||
134 | ) | 123 | ) |
135 | 124 | ||
136 | def __len__(self): | 125 | def __len__(self): |
137 | return math.ceil(self._length / self.batch_size) * self.batch_size | 126 | return self._length |
138 | 127 | ||
139 | def get_example(self, i): | 128 | def get_example(self, i): |
140 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] | 129 | instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
141 | 130 | ||
142 | if image_path in self.cache: | 131 | if instance_image_path in self.cache: |
143 | return self.cache[image_path] | 132 | return self.cache[instance_image_path] |
144 | 133 | ||
145 | example = {} | 134 | example = {} |
146 | 135 | ||
147 | instance_image = Image.open(image_path) | 136 | example["prompts"] = prompt |
137 | example["nprompts"] = nprompt | ||
138 | |||
139 | instance_image = Image.open(instance_image_path) | ||
148 | if not instance_image.mode == "RGB": | 140 | if not instance_image.mode == "RGB": |
149 | instance_image = instance_image.convert("RGB") | 141 | instance_image = instance_image.convert("RGB") |
150 | 142 | ||
151 | prompt = prompt.format(self.identifier) | 143 | instance_prompt = prompt.format(self.instance_identifier) |
152 | 144 | ||
153 | example["prompts"] = prompt | ||
154 | example["nprompts"] = nprompt | ||
155 | example["instance_images"] = instance_image | 145 | example["instance_images"] = instance_image |
156 | example["instance_prompt_ids"] = self.tokenizer( | 146 | example["instance_prompt_ids"] = self.tokenizer( |
157 | self.instance_prompt, | 147 | instance_prompt, |
158 | padding="do_not_pad", | 148 | padding="do_not_pad", |
159 | truncation=True, | 149 | truncation=True, |
160 | max_length=self.tokenizer.model_max_length, | 150 | max_length=self.tokenizer.model_max_length, |
161 | ).input_ids | 151 | ).input_ids |
162 | 152 | ||
163 | if self.class_data_root: | 153 | if self.class_identifier: |
164 | class_image = Image.open(self.class_images[i % self.num_class_images]) | 154 | class_image = Image.open(class_image_path) |
165 | if not class_image.mode == "RGB": | 155 | if not class_image.mode == "RGB": |
166 | class_image = class_image.convert("RGB") | 156 | class_image = class_image.convert("RGB") |
167 | 157 | ||
158 | class_prompt = prompt.format(self.class_identifier) | ||
159 | |||
168 | example["class_images"] = class_image | 160 | example["class_images"] = class_image |
169 | example["class_prompt_ids"] = self.tokenizer( | 161 | example["class_prompt_ids"] = self.tokenizer( |
170 | self.class_prompt, | 162 | class_prompt, |
171 | padding="do_not_pad", | 163 | padding="do_not_pad", |
172 | truncation=True, | 164 | truncation=True, |
173 | max_length=self.tokenizer.model_max_length, | 165 | max_length=self.tokenizer.model_max_length, |
174 | ).input_ids | 166 | ).input_ids |
175 | 167 | ||
176 | self.cache[image_path] = example | 168 | self.cache[instance_image_path] = example |
177 | return example | 169 | return example |
178 | 170 | ||
179 | def __getitem__(self, i): | 171 | def __getitem__(self, i): |
@@ -185,7 +177,7 @@ class CSVDataset(Dataset): | |||
185 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 177 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
186 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 178 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
187 | 179 | ||
188 | if self.class_data_root: | 180 | if self.class_identifier: |
189 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 181 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
190 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 182 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] |
191 | 183 | ||
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py index 34f510d..b3a83ce 100644 --- a/data/dreambooth/prompt.py +++ b/data/dreambooth/prompt.py | |||
@@ -2,8 +2,9 @@ from torch.utils.data import Dataset | |||
2 | 2 | ||
3 | 3 | ||
4 | class PromptDataset(Dataset): | 4 | class PromptDataset(Dataset): |
5 | def __init__(self, prompt, num_samples): | 5 | def __init__(self, prompt, nprompt, num_samples): |
6 | self.prompt = prompt | 6 | self.prompt = prompt |
7 | self.nprompt = nprompt | ||
7 | self.num_samples = num_samples | 8 | self.num_samples = num_samples |
8 | 9 | ||
9 | def __len__(self): | 10 | def __len__(self): |
@@ -12,5 +13,6 @@ class PromptDataset(Dataset): | |||
12 | def __getitem__(self, index): | 13 | def __getitem__(self, index): |
13 | example = {} | 14 | example = {} |
14 | example["prompt"] = self.prompt | 15 | example["prompt"] = self.prompt |
16 | example["nprompt"] = self.nprompt | ||
15 | example["index"] = index | 17 | example["index"] = index |
16 | return example | 18 | return example |