author     Volpeon <git@volpeon.ink>    2022-09-27 10:54:40 +0200
committer  Volpeon <git@volpeon.ink>    2022-09-27 10:54:40 +0200
commit     ae260060fddaaeccf2b68a1de51d5e780b099dfb (patch)
tree       793655f089c4c7f0686f38bcddc5da1f0026d6ed /data
parent     More cleanup (diff)
Undo textual inversion dataset improvements
Diffstat (limited to 'data')
-rw-r--r--  data/textual_inversion/csv.py | 31 +++++++++++++++++++++++----------
1 file changed, 21 insertions(+), 10 deletions(-)
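This commit reverts the dataset's image preprocessing from a torchvision transforms.Compose pipeline back to the manual score-sde-style path: a numpy center crop, a PIL resize with the configurable interpolation filter, a random horizontal flip, and a rescale of uint8 pixels to float32 in [-1, 1]. The two scaling conventions agree: ToTensor() maps [0, 255] to [0, 1] and Normalize([0.5], [0.5]) computes (x - 0.5) / 0.5, so both paths end at 2x/255 - 1. A minimal sketch (not from the repository; the test image is synthetic) checking that equivalence:

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Synthetic 8-bit RGB image, standing in for a dataset sample.
img = Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))

# Removed pipeline: ToTensor() maps [0, 255] -> [0, 1], then
# Normalize([0.5], [0.5]) maps [0, 1] -> [-1, 1] via (x - 0.5) / 0.5.
old = transforms.Normalize([0.5], [0.5])(transforms.ToTensor()(img))

# Restored score-sde-style scaling: x / 127.5 - 1 maps [0, 255] -> [-1, 1].
new = torch.from_numpy(
    np.array(img).astype(np.float32) / 127.5 - 1.0
).permute(2, 0, 1)

assert torch.allclose(old, new, atol=1e-6)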
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 38ffb6f..0d1e96e 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -80,19 +80,14 @@ class CSVDataset(Dataset):
 
         self.placeholder_token = placeholder_token
 
+        self.size = size
+        self.center_crop = center_crop
         self.interpolation = {"linear": PIL.Image.LINEAR,
                               "bilinear": PIL.Image.BILINEAR,
                               "bicubic": PIL.Image.BICUBIC,
                               "lanczos": PIL.Image.LANCZOS,
                               }[interpolation]
-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
+        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
 
         self.cache = {}
 
@@ -107,9 +102,9 @@
 
         example = {}
         image = Image.open(image_path)
+
         if not image.mode == "RGB":
             image = image.convert("RGB")
-        image = self.image_transforms(image)
 
         text = text.format(self.placeholder_token)
 
@@ -122,8 +117,24 @@
             return_tensors="pt",
         ).input_ids[0]
 
+        # default to score-sde preprocessing
+        img = np.array(image).astype(np.uint8)
+
+        if self.center_crop:
+            crop = min(img.shape[0], img.shape[1])
+            h, w, = img.shape[0], img.shape[1]
+            img = img[(h - crop) // 2:(h + crop) // 2,
+                      (w - crop) // 2:(w + crop) // 2]
+
+        image = Image.fromarray(img)
+        image = image.resize((self.size, self.size),
+                             resample=self.interpolation)
+        image = self.flip(image)
+        image = np.array(image).astype(np.uint8)
+        image = (image / 127.5 - 1.0).astype(np.float32)
+
         example["key"] = "-".join([image_path, "-", str(flipped)])
-        example["pixel_values"] = image
+        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
 
         self.cache[image_path] = example
         return example
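For reference, here is a self-contained sketch of the preprocessing path this commit restores. The standalone preprocess() helper and its default arguments are hypothetical conveniences; each step mirrors the hunks above.

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

def preprocess(image_path, size=512, center_crop=True, flip_p=0.5,
               interpolation=Image.BICUBIC):
    # Hypothetical helper bundling the per-item steps from CSVDataset.
    image = Image.open(image_path)
    if not image.mode == "RGB":
        image = image.convert("RGB")

    img = np.array(image).astype(np.uint8)

    if center_crop:
        # Keep the largest centered square before resizing.
        crop = min(img.shape[0], img.shape[1])
        h, w = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

    image = Image.fromarray(img)
    image = image.resize((size, size), resample=interpolation)
    image = transforms.RandomHorizontalFlip(p=flip_p)(image)

    # Scale uint8 [0, 255] to float32 [-1, 1], then move channels first (CHW).
    arr = (np.array(image).astype(np.uint8) / 127.5 - 1.0).astype(np.float32)
    return torch.from_numpy(arr).permute(2, 0, 1)

Note one behavioral difference from the removed transforms.Compose version: with center_crop disabled, the restored code resizes the whole image to size x size, distorting the aspect ratio, whereas the old pipeline resized the shorter side and then took a random crop.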