diff options
| -rw-r--r-- | data/textual_inversion/csv.py (renamed from data.py) | 31 | ||||
| -rw-r--r-- | textual_inversion.py (renamed from main.py) | 5 |
2 files changed, 11 insertions, 25 deletions
diff --git a/data.py b/data/textual_inversion/csv.py
index 0d1e96e..38ffb6f 100644
--- a/data.py
+++ b/data/textual_inversion/csv.py
| @@ -80,14 +80,19 @@ class CSVDataset(Dataset): | |||
| 80 | 80 | ||
| 81 | self.placeholder_token = placeholder_token | 81 | self.placeholder_token = placeholder_token |
| 82 | 82 | ||
| 83 | self.size = size | ||
| 84 | self.center_crop = center_crop | ||
| 85 | self.interpolation = {"linear": PIL.Image.LINEAR, | 83 | self.interpolation = {"linear": PIL.Image.LINEAR, |
| 86 | "bilinear": PIL.Image.BILINEAR, | 84 | "bilinear": PIL.Image.BILINEAR, |
| 87 | "bicubic": PIL.Image.BICUBIC, | 85 | "bicubic": PIL.Image.BICUBIC, |
| 88 | "lanczos": PIL.Image.LANCZOS, | 86 | "lanczos": PIL.Image.LANCZOS, |
| 89 | }[interpolation] | 87 | }[interpolation] |
| 90 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) | 88 | self.image_transforms = transforms.Compose( |
| 89 | [ | ||
| 90 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), | ||
| 91 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), | ||
| 92 | transforms.ToTensor(), | ||
| 93 | transforms.Normalize([0.5], [0.5]), | ||
| 94 | ] | ||
| 95 | ) | ||
| 91 | 96 | ||
| 92 | self.cache = {} | 97 | self.cache = {} |
| 93 | 98 | ||
| @@ -102,9 +107,9 @@ class CSVDataset(Dataset): | |||
| 102 | 107 | ||
| 103 | example = {} | 108 | example = {} |
| 104 | image = Image.open(image_path) | 109 | image = Image.open(image_path) |
| 105 | |||
| 106 | if not image.mode == "RGB": | 110 | if not image.mode == "RGB": |
| 107 | image = image.convert("RGB") | 111 | image = image.convert("RGB") |
| 112 | image = self.image_transforms(image) | ||
| 108 | 113 | ||
| 109 | text = text.format(self.placeholder_token) | 114 | text = text.format(self.placeholder_token) |
| 110 | 115 | ||
| @@ -117,24 +122,8 @@ class CSVDataset(Dataset): | |||
| 117 | return_tensors="pt", | 122 | return_tensors="pt", |
| 118 | ).input_ids[0] | 123 | ).input_ids[0] |
| 119 | 124 | ||
| 120 | # default to score-sde preprocessing | ||
| 121 | img = np.array(image).astype(np.uint8) | ||
| 122 | |||
| 123 | if self.center_crop: | ||
| 124 | crop = min(img.shape[0], img.shape[1]) | ||
| 125 | h, w, = img.shape[0], img.shape[1] | ||
| 126 | img = img[(h - crop) // 2:(h + crop) // 2, | ||
| 127 | (w - crop) // 2:(w + crop) // 2] | ||
| 128 | |||
| 129 | image = Image.fromarray(img) | ||
| 130 | image = image.resize((self.size, self.size), | ||
| 131 | resample=self.interpolation) | ||
| 132 | image = self.flip(image) | ||
| 133 | image = np.array(image).astype(np.uint8) | ||
| 134 | image = (image / 127.5 - 1.0).astype(np.float32) | ||
| 135 | |||
| 136 | example["key"] = "-".join([image_path, "-", str(flipped)]) | 125 | example["key"] = "-".join([image_path, "-", str(flipped)]) |
| 137 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) | 126 | example["pixel_values"] = image |
| 138 | 127 | ||
| 139 | self.cache[image_path] = example | 128 | self.cache[image_path] = example |
| 140 | return example | 129 | return example |
diff --git a/main.py b/textual_inversion.py
index 51b64c1..aa8e744 100644
--- a/main.py
+++ b/textual_inversion.py
| @@ -2,10 +2,7 @@ import argparse | |||
| 2 | import itertools | 2 | import itertools |
| 3 | import math | 3 | import math |
| 4 | import os | 4 | import os |
| 5 | import random | ||
| 6 | import datetime | 5 | import datetime |
| 7 | from pathlib import Path | ||
| 8 | from typing import Optional | ||
| 9 | 6 | ||
| 10 | import numpy as np | 7 | import numpy as np |
| 11 | import torch | 8 | import torch |
| @@ -25,7 +22,7 @@ from slugify import slugify | |||
| 25 | import json | 22 | import json |
| 26 | import os | 23 | import os |
| 27 | 24 | ||
| 28 | from data import CSVDataModule | 25 | from data.textual_inversion.csv import CSVDataModule |
| 29 | 26 | ||
| 30 | logger = get_logger(__name__) | 27 | logger = get_logger(__name__) |
| 31 | 28 | ||
