summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/textual_inversion/csv.py (renamed from data.py)31
-rw-r--r--textual_inversion.py (renamed from main.py)5
2 files changed, 11 insertions, 25 deletions
diff --git a/data.py b/data/textual_inversion/csv.py
index 0d1e96e..38ffb6f 100644
--- a/data.py
+++ b/data/textual_inversion/csv.py
@@ -80,14 +80,19 @@ class CSVDataset(Dataset):
80 80
81 self.placeholder_token = placeholder_token 81 self.placeholder_token = placeholder_token
82 82
83 self.size = size
84 self.center_crop = center_crop
85 self.interpolation = {"linear": PIL.Image.LINEAR, 83 self.interpolation = {"linear": PIL.Image.LINEAR,
86 "bilinear": PIL.Image.BILINEAR, 84 "bilinear": PIL.Image.BILINEAR,
87 "bicubic": PIL.Image.BICUBIC, 85 "bicubic": PIL.Image.BICUBIC,
88 "lanczos": PIL.Image.LANCZOS, 86 "lanczos": PIL.Image.LANCZOS,
89 }[interpolation] 87 }[interpolation]
90 self.flip = transforms.RandomHorizontalFlip(p=flip_p) 88 self.image_transforms = transforms.Compose(
89 [
90 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
91 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
92 transforms.ToTensor(),
93 transforms.Normalize([0.5], [0.5]),
94 ]
95 )
91 96
92 self.cache = {} 97 self.cache = {}
93 98
@@ -102,9 +107,9 @@ class CSVDataset(Dataset):
102 107
103 example = {} 108 example = {}
104 image = Image.open(image_path) 109 image = Image.open(image_path)
105
106 if not image.mode == "RGB": 110 if not image.mode == "RGB":
107 image = image.convert("RGB") 111 image = image.convert("RGB")
112 image = self.image_transforms(image)
108 113
109 text = text.format(self.placeholder_token) 114 text = text.format(self.placeholder_token)
110 115
@@ -117,24 +122,8 @@ class CSVDataset(Dataset):
117 return_tensors="pt", 122 return_tensors="pt",
118 ).input_ids[0] 123 ).input_ids[0]
119 124
120 # default to score-sde preprocessing
121 img = np.array(image).astype(np.uint8)
122
123 if self.center_crop:
124 crop = min(img.shape[0], img.shape[1])
125 h, w, = img.shape[0], img.shape[1]
126 img = img[(h - crop) // 2:(h + crop) // 2,
127 (w - crop) // 2:(w + crop) // 2]
128
129 image = Image.fromarray(img)
130 image = image.resize((self.size, self.size),
131 resample=self.interpolation)
132 image = self.flip(image)
133 image = np.array(image).astype(np.uint8)
134 image = (image / 127.5 - 1.0).astype(np.float32)
135
136 example["key"] = "-".join([image_path, "-", str(flipped)]) 125 example["key"] = "-".join([image_path, "-", str(flipped)])
137 example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 126 example["pixel_values"] = image
138 127
139 self.cache[image_path] = example 128 self.cache[image_path] = example
140 return example 129 return example
diff --git a/main.py b/textual_inversion.py
index 51b64c1..aa8e744 100644
--- a/main.py
+++ b/textual_inversion.py
@@ -2,10 +2,7 @@ import argparse
2import itertools 2import itertools
3import math 3import math
4import os 4import os
5import random
6import datetime 5import datetime
7from pathlib import Path
8from typing import Optional
9 6
10import numpy as np 7import numpy as np
11import torch 8import torch
@@ -25,7 +22,7 @@ from slugify import slugify
25import json 22import json
26import os 23import os
27 24
28from data import CSVDataModule 25from data.textual_inversion.csv import CSVDataModule
29 26
30logger = get_logger(__name__) 27logger = get_logger(__name__)
31 28