diff options
-rw-r--r-- | data/textual_inversion/csv.py (renamed from data.py) | 31
-rw-r--r-- | textual_inversion.py (renamed from main.py) | 5
2 files changed, 11 insertions, 25 deletions
diff --git a/data.py b/data/textual_inversion/csv.py index 0d1e96e..38ffb6f 100644 --- a/data.py +++ b/data/textual_inversion/csv.py | |||
@@ -80,14 +80,19 @@ class CSVDataset(Dataset): | |||
80 | 80 | ||
81 | self.placeholder_token = placeholder_token | 81 | self.placeholder_token = placeholder_token |
82 | 82 | ||
83 | self.size = size | ||
84 | self.center_crop = center_crop | ||
85 | self.interpolation = {"linear": PIL.Image.LINEAR, | 83 | self.interpolation = {"linear": PIL.Image.LINEAR, |
86 | "bilinear": PIL.Image.BILINEAR, | 84 | "bilinear": PIL.Image.BILINEAR, |
87 | "bicubic": PIL.Image.BICUBIC, | 85 | "bicubic": PIL.Image.BICUBIC, |
88 | "lanczos": PIL.Image.LANCZOS, | 86 | "lanczos": PIL.Image.LANCZOS, |
89 | }[interpolation] | 87 | }[interpolation] |
90 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) | 88 | self.image_transforms = transforms.Compose( |
89 | [ | ||
90 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), | ||
91 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), | ||
92 | transforms.ToTensor(), | ||
93 | transforms.Normalize([0.5], [0.5]), | ||
94 | ] | ||
95 | ) | ||
91 | 96 | ||
92 | self.cache = {} | 97 | self.cache = {} |
93 | 98 | ||
@@ -102,9 +107,9 @@ class CSVDataset(Dataset): | |||
102 | 107 | ||
103 | example = {} | 108 | example = {} |
104 | image = Image.open(image_path) | 109 | image = Image.open(image_path) |
105 | |||
106 | if not image.mode == "RGB": | 110 | if not image.mode == "RGB": |
107 | image = image.convert("RGB") | 111 | image = image.convert("RGB") |
112 | image = self.image_transforms(image) | ||
108 | 113 | ||
109 | text = text.format(self.placeholder_token) | 114 | text = text.format(self.placeholder_token) |
110 | 115 | ||
@@ -117,24 +122,8 @@ class CSVDataset(Dataset): | |||
117 | return_tensors="pt", | 122 | return_tensors="pt", |
118 | ).input_ids[0] | 123 | ).input_ids[0] |
119 | 124 | ||
120 | # default to score-sde preprocessing | ||
121 | img = np.array(image).astype(np.uint8) | ||
122 | |||
123 | if self.center_crop: | ||
124 | crop = min(img.shape[0], img.shape[1]) | ||
125 | h, w, = img.shape[0], img.shape[1] | ||
126 | img = img[(h - crop) // 2:(h + crop) // 2, | ||
127 | (w - crop) // 2:(w + crop) // 2] | ||
128 | |||
129 | image = Image.fromarray(img) | ||
130 | image = image.resize((self.size, self.size), | ||
131 | resample=self.interpolation) | ||
132 | image = self.flip(image) | ||
133 | image = np.array(image).astype(np.uint8) | ||
134 | image = (image / 127.5 - 1.0).astype(np.float32) | ||
135 | |||
136 | example["key"] = "-".join([image_path, "-", str(flipped)]) | 125 | example["key"] = "-".join([image_path, "-", str(flipped)]) |
137 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) | 126 | example["pixel_values"] = image |
138 | 127 | ||
139 | self.cache[image_path] = example | 128 | self.cache[image_path] = example |
140 | return example | 129 | return example |
diff --git a/main.py b/textual_inversion.py index 51b64c1..aa8e744 100644 --- a/main.py +++ b/textual_inversion.py | |||
@@ -2,10 +2,7 @@ import argparse | |||
2 | import itertools | 2 | import itertools |
3 | import math | 3 | import math |
4 | import os | 4 | import os |
5 | import random | ||
6 | import datetime | 5 | import datetime |
7 | from pathlib import Path | ||
8 | from typing import Optional | ||
9 | 6 | ||
10 | import numpy as np | 7 | import numpy as np |
11 | import torch | 8 | import torch |
@@ -25,7 +22,7 @@ from slugify import slugify | |||
25 | import json | 22 | import json |
26 | import os | 23 | import os |
27 | 24 | ||
28 | from data import CSVDataModule | 25 | from data.textual_inversion.csv import CSVDataModule |
29 | 26 | ||
30 | logger = get_logger(__name__) | 27 | logger = get_logger(__name__) |
31 | 28 | ||