From ae260060fddaaeccf2b68a1de51d5e780b099dfb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Sep 2022 10:54:40 +0200 Subject: Undo textual inversion dataset improvements --- data/textual_inversion/csv.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) (limited to 'data') diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 38ffb6f..0d1e96e 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py @@ -80,19 +80,14 @@ class CSVDataset(Dataset): self.placeholder_token = placeholder_token + self.size = size + self.center_crop = center_crop self.interpolation = {"linear": PIL.Image.LINEAR, "bilinear": PIL.Image.BILINEAR, "bicubic": PIL.Image.BICUBIC, "lanczos": PIL.Image.LANCZOS, }[interpolation] - self.image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) + self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.cache = {} @@ -107,9 +102,9 @@ class CSVDataset(Dataset): example = {} image = Image.open(image_path) + if not image.mode == "RGB": image = image.convert("RGB") - image = self.image_transforms(image) text = text.format(self.placeholder_token) @@ -122,8 +117,24 @@ class CSVDataset(Dataset): return_tensors="pt", ).input_ids[0] + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + image = image.resize((self.size, self.size), + resample=self.interpolation) + image = self.flip(image) + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + example["key"] = "-".join([image_path, "-", str(flipped)]) - example["pixel_values"] = image + example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) self.cache[image_path] = example return example -- cgit v1.2.3-70-g09d2