| field | value |
|---|---|
| author | Volpeon <git@volpeon.ink>, 2022-09-27 10:54:40 +0200 |
| committer | Volpeon <git@volpeon.ink>, 2022-09-27 10:54:40 +0200 |
| commit | ae260060fddaaeccf2b68a1de51d5e780b099dfb (patch) |
| tree | 793655f089c4c7f0686f38bcddc5da1f0026d6ed /data |
| parent | More cleanup (diff) |
Undo textual inversion dataset improvements
Diffstat (limited to 'data')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | data/textual_inversion/csv.py | 31 |
1 file changed, 21 insertions(+), 10 deletions(-)
```diff
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 38ffb6f..0d1e96e 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -80,19 +80,14 @@ class CSVDataset(Dataset):
 
         self.placeholder_token = placeholder_token
 
+        self.size = size
+        self.center_crop = center_crop
         self.interpolation = {"linear": PIL.Image.LINEAR,
                               "bilinear": PIL.Image.BILINEAR,
                               "bicubic": PIL.Image.BICUBIC,
                               "lanczos": PIL.Image.LANCZOS,
                               }[interpolation]
-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
+        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
 
         self.cache = {}
 
@@ -107,9 +102,9 @@ class CSVDataset(Dataset):
 
         example = {}
         image = Image.open(image_path)
+
         if not image.mode == "RGB":
             image = image.convert("RGB")
-        image = self.image_transforms(image)
 
         text = text.format(self.placeholder_token)
 
@@ -122,8 +117,24 @@
             return_tensors="pt",
         ).input_ids[0]
 
+        # default to score-sde preprocessing
+        img = np.array(image).astype(np.uint8)
+
+        if self.center_crop:
+            crop = min(img.shape[0], img.shape[1])
+            h, w, = img.shape[0], img.shape[1]
+            img = img[(h - crop) // 2:(h + crop) // 2,
+                      (w - crop) // 2:(w + crop) // 2]
+
+        image = Image.fromarray(img)
+        image = image.resize((self.size, self.size),
+                             resample=self.interpolation)
+        image = self.flip(image)
+        image = np.array(image).astype(np.uint8)
+        image = (image / 127.5 - 1.0).astype(np.float32)
+
         example["key"] = "-".join([image_path, "-", str(flipped)])
-        example["pixel_values"] = image
+        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
 
         self.cache[image_path] = example
         return example
```
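For reference, the score-sde-style preprocessing this commit restores can be read as a standalone function. The sketch below is illustrative only: the helper name `preprocess` and its signature are not part of the commit (in the actual dataset this logic lives inline in `CSVDataset.__getitem__`), but the steps mirror the diff: optional center crop to the largest square, resize to `size` with the chosen resampling filter, random horizontal flip, and rescaling of pixel values from [0, 255] to [-1, 1].

```python
# Minimal sketch of the restored preprocessing path (assumed helper name;
# not part of the commit itself).
import numpy as np
import torch
from PIL import Image
from torchvision import transforms


def preprocess(image: Image.Image, size: int, center_crop: bool = False,
               interpolation=Image.BICUBIC, flip_p: float = 0.5) -> torch.Tensor:
    if image.mode != "RGB":
        image = image.convert("RGB")

    img = np.array(image).astype(np.uint8)

    if center_crop:
        # Crop the largest centered square before resizing.
        crop = min(img.shape[0], img.shape[1])
        h, w = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

    image = Image.fromarray(img).resize((size, size), resample=interpolation)
    image = transforms.RandomHorizontalFlip(p=flip_p)(image)

    # Map uint8 values [0, 255] to float32 [-1, 1]: x / 127.5 - 1.
    arr = (np.array(image).astype(np.uint8) / 127.5 - 1.0).astype(np.float32)

    # HWC -> CHW tensor, as stored in example["pixel_values"].
    return torch.from_numpy(arr).permute(2, 0, 1)
```

Relative to the removed `transforms.Compose` pipeline, the restored code crops before resizing, drops the `RandomCrop` branch for the non-center-crop case, and reaches the same [-1, 1] range as `ToTensor()` followed by `Normalize([0.5], [0.5])`, since (x/255 - 0.5)/0.5 = x/127.5 - 1.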
