summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-04 19:22:22 +0200
committerVolpeon <git@volpeon.ink>2022-10-04 19:22:22 +0200
commit300deaa789a0321f32d5e7f04d9860eaa258110e (patch)
tree892e89753e5c4d86d787131595751bc03c610be8 /data
parentDefault sample steps 30 -> 40 (diff)
downloadtextual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.gz
textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.bz2
textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.zip
Add Textual Inversion with class dataset (a la Dreambooth)
Diffstat (limited to 'data')
-rw-r--r--data/dreambooth/csv.py11
-rw-r--r--data/dreambooth/prompt.py18
2 files changed, 4 insertions, 25 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 9075979..abd329d 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -15,6 +15,7 @@ class CSVDataModule(pl.LightningDataModule):
15 tokenizer, 15 tokenizer,
16 instance_identifier, 16 instance_identifier,
17 class_identifier=None, 17 class_identifier=None,
18 class_subdir="db_cls",
18 size=512, 19 size=512,
19 repeats=100, 20 repeats=100,
20 interpolation="bicubic", 21 interpolation="bicubic",
@@ -30,7 +31,7 @@ class CSVDataModule(pl.LightningDataModule):
30 raise ValueError("data_file must be a file") 31 raise ValueError("data_file must be a file")
31 32
32 self.data_root = self.data_file.parent 33 self.data_root = self.data_file.parent
33 self.class_root = self.data_root.joinpath("db_cls") 34 self.class_root = self.data_root.joinpath(class_subdir)
34 self.class_root.mkdir(parents=True, exist_ok=True) 35 self.class_root.mkdir(parents=True, exist_ok=True)
35 36
36 self.tokenizer = tokenizer 37 self.tokenizer = tokenizer
@@ -140,11 +141,9 @@ class CSVDataset(Dataset):
140 if not instance_image.mode == "RGB": 141 if not instance_image.mode == "RGB":
141 instance_image = instance_image.convert("RGB") 142 instance_image = instance_image.convert("RGB")
142 143
143 instance_prompt = prompt.format(self.instance_identifier)
144
145 example["instance_images"] = instance_image 144 example["instance_images"] = instance_image
146 example["instance_prompt_ids"] = self.tokenizer( 145 example["instance_prompt_ids"] = self.tokenizer(
147 instance_prompt, 146 prompt.format(self.instance_identifier),
148 padding="do_not_pad", 147 padding="do_not_pad",
149 truncation=True, 148 truncation=True,
150 max_length=self.tokenizer.model_max_length, 149 max_length=self.tokenizer.model_max_length,
@@ -155,11 +154,9 @@ class CSVDataset(Dataset):
155 if not class_image.mode == "RGB": 154 if not class_image.mode == "RGB":
156 class_image = class_image.convert("RGB") 155 class_image = class_image.convert("RGB")
157 156
158 class_prompt = prompt.format(self.class_identifier)
159
160 example["class_images"] = class_image 157 example["class_images"] = class_image
161 example["class_prompt_ids"] = self.tokenizer( 158 example["class_prompt_ids"] = self.tokenizer(
162 class_prompt, 159 prompt.format(self.class_identifier),
163 padding="do_not_pad", 160 padding="do_not_pad",
164 truncation=True, 161 truncation=True,
165 max_length=self.tokenizer.model_max_length, 162 max_length=self.tokenizer.model_max_length,
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
deleted file mode 100644
index b3a83ce..0000000
--- a/data/dreambooth/prompt.py
+++ /dev/null
@@ -1,18 +0,0 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 def __init__(self, prompt, nprompt, num_samples):
6 self.prompt = prompt
7 self.nprompt = nprompt
8 self.num_samples = num_samples
9
10 def __len__(self):
11 return self.num_samples
12
13 def __getitem__(self, index):
14 example = {}
15 example["prompt"] = self.prompt
16 example["nprompt"] = self.nprompt
17 example["index"] = index
18 return example