Diffstat (limited to 'data/textual_inversion')
 data/textual_inversion/csv.py | 98
 1 file changed, 46 insertions(+), 52 deletions(-)
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 0d1e96e..f306c7a 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -1,11 +1,10 @@
 import os
 import numpy as np
 import pandas as pd
-import random
-import PIL
+from pathlib import Path
+import math
 import pytorch_lightning as pl
 from PIL import Image
-import torch
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
 
@@ -13,29 +12,32 @@ from torchvision import transforms
 class CSVDataModule(pl.LightningDataModule):
     def __init__(self,
                  batch_size,
-                 data_root,
+                 data_file,
                  tokenizer,
                  size=512,
                  repeats=100,
                  interpolation="bicubic",
                  placeholder_token="*",
-                 flip_p=0.5,
                  center_crop=False):
         super().__init__()
 
-        self.data_root = data_root
+        self.data_file = Path(data_file)
+
+        if not self.data_file.is_file():
+            raise ValueError("data_file must be a file")
+
+        self.data_root = self.data_file.parent
         self.tokenizer = tokenizer
         self.size = size
         self.repeats = repeats
         self.placeholder_token = placeholder_token
         self.center_crop = center_crop
-        self.flip_p = flip_p
         self.interpolation = interpolation
 
         self.batch_size = batch_size
 
     def prepare_data(self):
-        metadata = pd.read_csv(f'{self.data_root}/list.csv')
+        metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
         captions = [caption for caption in metadata['caption'].values]
         skips = [skip for skip in metadata['skip'].values]
@@ -47,9 +49,9 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
-                                   flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+                                   placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
-                                 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+                                 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
         self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
 
@@ -67,48 +69,54 @@ class CSVDataset(Dataset):
                  size=512,
                  repeats=1,
                  interpolation="bicubic",
-                 flip_p=0.5,
                  placeholder_token="*",
                  center_crop=False,
+                 batch_size=1,
                  ):
 
         self.data = data
         self.tokenizer = tokenizer
-
-        self.num_images = len(self.data)
-        self._length = self.num_images * repeats
-
         self.placeholder_token = placeholder_token
+        self.batch_size = batch_size
+        self.cache = {}
 
-        self.size = size
-        self.center_crop = center_crop
-        self.interpolation = {"linear": PIL.Image.LINEAR,
-                              "bilinear": PIL.Image.BILINEAR,
-                              "bicubic": PIL.Image.BICUBIC,
-                              "lanczos": PIL.Image.LANCZOS,
-                              }[interpolation]
-        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+        self.num_instance_images = len(self.data)
+        self._length = self.num_instance_images * repeats
 
-        self.cache = {}
+        self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
+                              "bilinear": transforms.InterpolationMode.BILINEAR,
+                              "bicubic": transforms.InterpolationMode.BICUBIC,
+                              "lanczos": transforms.InterpolationMode.LANCZOS,
+                              }[interpolation]
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=self.interpolation),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
 
     def __len__(self):
-        return self._length
+        return math.ceil(self._length / self.batch_size) * self.batch_size
 
-    def get_example(self, i, flipped):
-        image_path, text = self.data[i % self.num_images]
+    def get_example(self, i):
+        image_path, text = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
 
         example = {}
-        image = Image.open(image_path)
 
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
+        instance_image = Image.open(image_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
 
         text = text.format(self.placeholder_token)
 
-        example["prompt"] = text
+        example["prompts"] = text
+        example["pixel_values"] = instance_image
         example["input_ids"] = self.tokenizer(
             text,
             padding="max_length",
@@ -117,29 +125,15 @@ class CSVDataset(Dataset):
             return_tensors="pt",
         ).input_ids[0]
 
-        # default to score-sde preprocessing
-        img = np.array(image).astype(np.uint8)
-
-        if self.center_crop:
-            crop = min(img.shape[0], img.shape[1])
-            h, w, = img.shape[0], img.shape[1]
-            img = img[(h - crop) // 2:(h + crop) // 2,
-                      (w - crop) // 2:(w + crop) // 2]
-
-        image = Image.fromarray(img)
-        image = image.resize((self.size, self.size),
-                             resample=self.interpolation)
-        image = self.flip(image)
-        image = np.array(image).astype(np.uint8)
-        image = (image / 127.5 - 1.0).astype(np.float32)
-
-        example["key"] = "-".join([image_path, "-", str(flipped)])
-        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
-
         self.cache[image_path] = example
         return example
 
     def __getitem__(self, i):
-        flipped = random.choice([False, True])
-        example = self.get_example(i, flipped)
+        example = {}
+        unprocessed_example = self.get_example(i)
+
+        example["prompts"] = unprocessed_example["prompts"]
+        example["input_ids"] = unprocessed_example["input_ids"]
+        example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
+
        return example
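For orientation, a minimal usage sketch of the revised CSVDataset follows. It is not part of this change: the CLIP checkpoint, the image path, and the sample caption are illustrative assumptions, while the constructor arguments and the example keys (prompts, input_ids, pixel_values) come from the diff above.

# Hypothetical usage sketch (not part of this commit). Names that do not appear
# in the diff -- the CLIP checkpoint, the image path, the caption -- are
# assumptions for illustration only.
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

from data.textual_inversion.csv import CSVDataset

# Assumed tokenizer; any HF-style tokenizer accepting padding/return_tensors works.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# CSVDataset indexes (image_path, caption) pairs; a "{}" slot in the caption is
# filled with the placeholder token via text.format(...).
data = [("images/0001.jpg", "a photo of {}")]

dataset = CSVDataset(data, tokenizer, size=512, repeats=100,
                     interpolation="bicubic", placeholder_token="*",
                     center_crop=False, batch_size=4)

example = dataset[0]
print(example["prompts"])             # "a photo of *"
print(example["input_ids"].shape)     # tokenized prompt, padded to max_length
print(example["pixel_values"].shape)  # torch.Size([3, 512, 512]) after the transforms

# __len__ is rounded up to a multiple of batch_size, so a DataLoader whose
# batch size matches the dataset's never yields a short final batch.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

Note that CSVDataModule itself still constructs its datasets with the default batch_size=1 in this diff, so the length rounding only takes effect when a larger batch_size is passed to CSVDataset directly.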