From 300deaa789a0321f32d5e7f04d9860eaa258110e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Oct 2022 19:22:22 +0200 Subject: Add Textual Inversion with class dataset (a la Dreambooth) --- data/dreambooth/csv.py | 11 ++++------- data/dreambooth/prompt.py | 18 ------------------ 2 files changed, 4 insertions(+), 25 deletions(-) delete mode 100644 data/dreambooth/prompt.py (limited to 'data') diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 9075979..abd329d 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -15,6 +15,7 @@ class CSVDataModule(pl.LightningDataModule): tokenizer, instance_identifier, class_identifier=None, + class_subdir="db_cls", size=512, repeats=100, interpolation="bicubic", @@ -30,7 +31,7 @@ class CSVDataModule(pl.LightningDataModule): raise ValueError("data_file must be a file") self.data_root = self.data_file.parent - self.class_root = self.data_root.joinpath("db_cls") + self.class_root = self.data_root.joinpath(class_subdir) self.class_root.mkdir(parents=True, exist_ok=True) self.tokenizer = tokenizer @@ -140,11 +141,9 @@ class CSVDataset(Dataset): if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") - instance_prompt = prompt.format(self.instance_identifier) - example["instance_images"] = instance_image example["instance_prompt_ids"] = self.tokenizer( - instance_prompt, + prompt.format(self.instance_identifier), padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, @@ -155,11 +154,9 @@ class CSVDataset(Dataset): if not class_image.mode == "RGB": class_image = class_image.convert("RGB") - class_prompt = prompt.format(self.class_identifier) - example["class_images"] = class_image example["class_prompt_ids"] = self.tokenizer( - class_prompt, + prompt.format(self.class_identifier), padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py deleted file mode 100644 index b3a83ce..0000000 --- a/data/dreambooth/prompt.py +++ /dev/null @@ -1,18 +0,0 @@ -from torch.utils.data import Dataset - - -class PromptDataset(Dataset): - def __init__(self, prompt, nprompt, num_samples): - self.prompt = prompt - self.nprompt = nprompt - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, index): - example = {} - example["prompt"] = self.prompt - example["nprompt"] = self.nprompt - example["index"] = index - return example -- cgit v1.2.3-70-g09d2