diff options
Diffstat (limited to 'data/textual_inversion')
-rw-r--r-- | data/textual_inversion/csv.py | 31 |
1 file changed, 21 insertions(+), 10 deletions(-)
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 38ffb6f..0d1e96e 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
@@ -80,19 +80,14 @@ class CSVDataset(Dataset): | |||
80 | 80 | ||
81 | self.placeholder_token = placeholder_token | 81 | self.placeholder_token = placeholder_token |
82 | 82 | ||
83 | self.size = size | ||
84 | self.center_crop = center_crop | ||
83 | self.interpolation = {"linear": PIL.Image.LINEAR, | 85 | self.interpolation = {"linear": PIL.Image.LINEAR, |
84 | "bilinear": PIL.Image.BILINEAR, | 86 | "bilinear": PIL.Image.BILINEAR, |
85 | "bicubic": PIL.Image.BICUBIC, | 87 | "bicubic": PIL.Image.BICUBIC, |
86 | "lanczos": PIL.Image.LANCZOS, | 88 | "lanczos": PIL.Image.LANCZOS, |
87 | }[interpolation] | 89 | }[interpolation] |
88 | self.image_transforms = transforms.Compose( | 90 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) |
89 | [ | ||
90 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), | ||
91 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), | ||
92 | transforms.ToTensor(), | ||
93 | transforms.Normalize([0.5], [0.5]), | ||
94 | ] | ||
95 | ) | ||
96 | 91 | ||
97 | self.cache = {} | 92 | self.cache = {} |
98 | 93 | ||
@@ -107,9 +102,9 @@ class CSVDataset(Dataset): | |||
107 | 102 | ||
108 | example = {} | 103 | example = {} |
109 | image = Image.open(image_path) | 104 | image = Image.open(image_path) |
105 | |||
110 | if not image.mode == "RGB": | 106 | if not image.mode == "RGB": |
111 | image = image.convert("RGB") | 107 | image = image.convert("RGB") |
112 | image = self.image_transforms(image) | ||
113 | 108 | ||
114 | text = text.format(self.placeholder_token) | 109 | text = text.format(self.placeholder_token) |
115 | 110 | ||
@@ -122,8 +117,24 @@ class CSVDataset(Dataset): | |||
122 | return_tensors="pt", | 117 | return_tensors="pt", |
123 | ).input_ids[0] | 118 | ).input_ids[0] |
124 | 119 | ||
120 | # default to score-sde preprocessing | ||
121 | img = np.array(image).astype(np.uint8) | ||
122 | |||
123 | if self.center_crop: | ||
124 | crop = min(img.shape[0], img.shape[1]) | ||
125 | h, w, = img.shape[0], img.shape[1] | ||
126 | img = img[(h - crop) // 2:(h + crop) // 2, | ||
127 | (w - crop) // 2:(w + crop) // 2] | ||
128 | |||
129 | image = Image.fromarray(img) | ||
130 | image = image.resize((self.size, self.size), | ||
131 | resample=self.interpolation) | ||
132 | image = self.flip(image) | ||
133 | image = np.array(image).astype(np.uint8) | ||
134 | image = (image / 127.5 - 1.0).astype(np.float32) | ||
135 | |||
125 | example["key"] = "-".join([image_path, "-", str(flipped)]) | 136 | example["key"] = "-".join([image_path, "-", str(flipped)]) |
126 | example["pixel_values"] = image | 137 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) |
127 | 138 | ||
128 | self.cache[image_path] = example | 139 | self.cache[image_path] = example |
129 | return example | 140 | return example |