summaryrefslogtreecommitdiffstats
path: root/data/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'data/textual_inversion')
-rw-r--r--data/textual_inversion/csv.py98
1 files changed, 46 insertions, 52 deletions
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 0d1e96e..f306c7a 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -1,11 +1,10 @@
1import os 1import os
2import numpy as np 2import numpy as np
3import pandas as pd 3import pandas as pd
4import random 4from pathlib import Path
5import PIL 5import math
6import pytorch_lightning as pl 6import pytorch_lightning as pl
7from PIL import Image 7from PIL import Image
8import torch
9from torch.utils.data import Dataset, DataLoader, random_split 8from torch.utils.data import Dataset, DataLoader, random_split
10from torchvision import transforms 9from torchvision import transforms
11 10
@@ -13,29 +12,32 @@ from torchvision import transforms
13class CSVDataModule(pl.LightningDataModule): 12class CSVDataModule(pl.LightningDataModule):
14 def __init__(self, 13 def __init__(self,
15 batch_size, 14 batch_size,
16 data_root, 15 data_file,
17 tokenizer, 16 tokenizer,
18 size=512, 17 size=512,
19 repeats=100, 18 repeats=100,
20 interpolation="bicubic", 19 interpolation="bicubic",
21 placeholder_token="*", 20 placeholder_token="*",
22 flip_p=0.5,
23 center_crop=False): 21 center_crop=False):
24 super().__init__() 22 super().__init__()
25 23
26 self.data_root = data_root 24 self.data_file = Path(data_file)
25
26 if not self.data_file.is_file():
27 raise ValueError("data_file must be a file")
28
29 self.data_root = self.data_file.parent
27 self.tokenizer = tokenizer 30 self.tokenizer = tokenizer
28 self.size = size 31 self.size = size
29 self.repeats = repeats 32 self.repeats = repeats
30 self.placeholder_token = placeholder_token 33 self.placeholder_token = placeholder_token
31 self.center_crop = center_crop 34 self.center_crop = center_crop
32 self.flip_p = flip_p
33 self.interpolation = interpolation 35 self.interpolation = interpolation
34 36
35 self.batch_size = batch_size 37 self.batch_size = batch_size
36 38
37 def prepare_data(self): 39 def prepare_data(self):
38 metadata = pd.read_csv(f'{self.data_root}/list.csv') 40 metadata = pd.read_csv(self.data_file)
39 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] 41 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
40 captions = [caption for caption in metadata['caption'].values] 42 captions = [caption for caption in metadata['caption'].values]
41 skips = [skip for skip in metadata['skip'].values] 43 skips = [skip for skip in metadata['skip'].values]
@@ -47,9 +49,9 @@ class CSVDataModule(pl.LightningDataModule):
47 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) 49 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
48 50
49 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, 51 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
50 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) 52 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
51 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, 53 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
52 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) 54 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
53 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) 55 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
54 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) 56 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
55 57
@@ -67,48 +69,54 @@ class CSVDataset(Dataset):
67 size=512, 69 size=512,
68 repeats=1, 70 repeats=1,
69 interpolation="bicubic", 71 interpolation="bicubic",
70 flip_p=0.5,
71 placeholder_token="*", 72 placeholder_token="*",
72 center_crop=False, 73 center_crop=False,
74 batch_size=1,
73 ): 75 ):
74 76
75 self.data = data 77 self.data = data
76 self.tokenizer = tokenizer 78 self.tokenizer = tokenizer
77
78 self.num_images = len(self.data)
79 self._length = self.num_images * repeats
80
81 self.placeholder_token = placeholder_token 79 self.placeholder_token = placeholder_token
80 self.batch_size = batch_size
81 self.cache = {}
82 82
83 self.size = size 83 self.num_instance_images = len(self.data)
84 self.center_crop = center_crop 84 self._length = self.num_instance_images * repeats
85 self.interpolation = {"linear": PIL.Image.LINEAR,
86 "bilinear": PIL.Image.BILINEAR,
87 "bicubic": PIL.Image.BICUBIC,
88 "lanczos": PIL.Image.LANCZOS,
89 }[interpolation]
90 self.flip = transforms.RandomHorizontalFlip(p=flip_p)
91 85
92 self.cache = {} 86 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
87 "bilinear": transforms.InterpolationMode.BILINEAR,
88 "bicubic": transforms.InterpolationMode.BICUBIC,
89 "lanczos": transforms.InterpolationMode.LANCZOS,
90 }[interpolation]
91 self.image_transforms = transforms.Compose(
92 [
93 transforms.Resize(size, interpolation=self.interpolation),
94 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
95 transforms.RandomHorizontalFlip(),
96 transforms.ToTensor(),
97 transforms.Normalize([0.5], [0.5]),
98 ]
99 )
93 100
94 def __len__(self): 101 def __len__(self):
95 return self._length 102 return math.ceil(self._length / self.batch_size) * self.batch_size
96 103
97 def get_example(self, i, flipped): 104 def get_example(self, i):
98 image_path, text = self.data[i % self.num_images] 105 image_path, text = self.data[i % self.num_instance_images]
99 106
100 if image_path in self.cache: 107 if image_path in self.cache:
101 return self.cache[image_path] 108 return self.cache[image_path]
102 109
103 example = {} 110 example = {}
104 image = Image.open(image_path)
105 111
106 if not image.mode == "RGB": 112 instance_image = Image.open(image_path)
107 image = image.convert("RGB") 113 if not instance_image.mode == "RGB":
114 instance_image = instance_image.convert("RGB")
108 115
109 text = text.format(self.placeholder_token) 116 text = text.format(self.placeholder_token)
110 117
111 example["prompt"] = text 118 example["prompts"] = text
119 example["pixel_values"] = instance_image
112 example["input_ids"] = self.tokenizer( 120 example["input_ids"] = self.tokenizer(
113 text, 121 text,
114 padding="max_length", 122 padding="max_length",
@@ -117,29 +125,15 @@ class CSVDataset(Dataset):
117 return_tensors="pt", 125 return_tensors="pt",
118 ).input_ids[0] 126 ).input_ids[0]
119 127
120 # default to score-sde preprocessing
121 img = np.array(image).astype(np.uint8)
122
123 if self.center_crop:
124 crop = min(img.shape[0], img.shape[1])
125 h, w, = img.shape[0], img.shape[1]
126 img = img[(h - crop) // 2:(h + crop) // 2,
127 (w - crop) // 2:(w + crop) // 2]
128
129 image = Image.fromarray(img)
130 image = image.resize((self.size, self.size),
131 resample=self.interpolation)
132 image = self.flip(image)
133 image = np.array(image).astype(np.uint8)
134 image = (image / 127.5 - 1.0).astype(np.float32)
135
136 example["key"] = "-".join([image_path, "-", str(flipped)])
137 example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
138
139 self.cache[image_path] = example 128 self.cache[image_path] = example
140 return example 129 return example
141 130
142 def __getitem__(self, i): 131 def __getitem__(self, i):
143 flipped = random.choice([False, True]) 132 example = {}
144 example = self.get_example(i, flipped) 133 unprocessed_example = self.get_example(i)
134
135 example["prompts"] = unprocessed_example["prompts"]
136 example["input_ids"] = unprocessed_example["input_ids"]
137 example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
138
145 return example 139 return example