summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py (renamed from data/dreambooth/csv.py)0
-rw-r--r--data/textual_inversion/csv.py150
2 files changed, 0 insertions, 150 deletions
diff --git a/data/dreambooth/csv.py b/data/csv.py
index abd329d..abd329d 100644
--- a/data/dreambooth/csv.py
+++ b/data/csv.py
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
deleted file mode 100644
index 4c5e27e..0000000
--- a/data/textual_inversion/csv.py
+++ /dev/null
@@ -1,150 +0,0 @@
1import os
2import numpy as np
3import pandas as pd
4from pathlib import Path
5import math
6import pytorch_lightning as pl
7from PIL import Image
8from torch.utils.data import Dataset, DataLoader, random_split
9from torchvision import transforms
10
11
12class CSVDataModule(pl.LightningDataModule):
13 def __init__(self,
14 batch_size,
15 data_file,
16 tokenizer,
17 size=512,
18 repeats=100,
19 interpolation="bicubic",
20 placeholder_token="*",
21 center_crop=False,
22 valid_set_size=None,
23 generator=None):
24 super().__init__()
25
26 self.data_file = Path(data_file)
27
28 if not self.data_file.is_file():
29 raise ValueError("data_file must be a file")
30
31 self.data_root = self.data_file.parent
32 self.tokenizer = tokenizer
33 self.size = size
34 self.repeats = repeats
35 self.placeholder_token = placeholder_token
36 self.center_crop = center_crop
37 self.interpolation = interpolation
38 self.valid_set_size = valid_set_size
39 self.generator = generator
40
41 self.batch_size = batch_size
42
43 def prepare_data(self):
44 metadata = pd.read_csv(self.data_file)
45 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
46 prompts = metadata['prompt'].values
47 nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
48 skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
49 self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
50
51 def setup(self, stage=None):
52 valid_set_size = int(len(self.data_full) * 0.2)
53 if self.valid_set_size:
54 valid_set_size = min(valid_set_size, self.valid_set_size)
55 valid_set_size = max(valid_set_size, 1)
56 train_set_size = len(self.data_full) - valid_set_size
57
58 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator)
59
60 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
61 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
62 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
63 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
64 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
65 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)
66
67 def train_dataloader(self):
68 return self.train_dataloader_
69
70 def val_dataloader(self):
71 return self.val_dataloader_
72
73
74class CSVDataset(Dataset):
75 def __init__(self,
76 data,
77 tokenizer,
78 size=512,
79 repeats=1,
80 interpolation="bicubic",
81 placeholder_token="*",
82 center_crop=False,
83 batch_size=1,
84 ):
85
86 self.data = data
87 self.tokenizer = tokenizer
88 self.placeholder_token = placeholder_token
89 self.batch_size = batch_size
90 self.cache = {}
91
92 self.num_instance_images = len(self.data)
93 self._length = self.num_instance_images * repeats
94
95 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
96 "bilinear": transforms.InterpolationMode.BILINEAR,
97 "bicubic": transforms.InterpolationMode.BICUBIC,
98 "lanczos": transforms.InterpolationMode.LANCZOS,
99 }[interpolation]
100 self.image_transforms = transforms.Compose(
101 [
102 transforms.Resize(size, interpolation=self.interpolation),
103 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
104 transforms.RandomHorizontalFlip(),
105 transforms.ToTensor(),
106 transforms.Normalize([0.5], [0.5]),
107 ]
108 )
109
110 def __len__(self):
111 return math.ceil(self._length / self.batch_size) * self.batch_size
112
113 def get_example(self, i):
114 image_path, prompt, nprompt = self.data[i % self.num_instance_images]
115
116 if image_path in self.cache:
117 return self.cache[image_path]
118
119 example = {}
120
121 instance_image = Image.open(image_path)
122 if not instance_image.mode == "RGB":
123 instance_image = instance_image.convert("RGB")
124
125 prompt = prompt.format(self.placeholder_token)
126
127 example["prompts"] = prompt
128 example["nprompts"] = nprompt
129 example["pixel_values"] = instance_image
130 example["input_ids"] = self.tokenizer(
131 prompt,
132 padding="max_length",
133 truncation=True,
134 max_length=self.tokenizer.model_max_length,
135 return_tensors="pt",
136 ).input_ids[0]
137
138 self.cache[image_path] = example
139 return example
140
141 def __getitem__(self, i):
142 example = {}
143 unprocessed_example = self.get_example(i)
144
145 example["prompts"] = unprocessed_example["prompts"]
146 example["nprompts"] = unprocessed_example["nprompts"]
147 example["input_ids"] = unprocessed_example["input_ids"]
148 example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
149
150 return example