From 46b6c09a18b41edff77c6881529b66733d788abe Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 21:28:52 +0200 Subject: Dreambooth: Generate specialized class images from input prompts --- data/dreambooth/csv.py | 112 +++++++++++++++++++++------------------------- data/dreambooth/prompt.py | 4 +- 2 files changed, 55 insertions(+), 61 deletions(-) (limited to 'data/dreambooth') diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index c0b0067..4ebdc13 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -13,13 +13,11 @@ class CSVDataModule(pl.LightningDataModule): batch_size, data_file, tokenizer, - instance_prompt, - class_data_root=None, - class_prompt=None, + instance_identifier, + class_identifier=None, size=512, repeats=100, interpolation="bicubic", - identifier="*", center_crop=False, valid_set_size=None, generator=None, @@ -32,13 +30,14 @@ class CSVDataModule(pl.LightningDataModule): raise ValueError("data_file must be a file") self.data_root = self.data_file.parent + self.class_root = self.data_root.joinpath("db_cls") + self.class_root.mkdir(parents=True, exist_ok=True) + self.tokenizer = tokenizer - self.instance_prompt = instance_prompt - self.class_data_root = class_data_root - self.class_prompt = class_prompt + self.instance_identifier = instance_identifier + self.class_identifier = class_identifier self.size = size self.repeats = repeats - self.identifier = identifier self.center_crop = center_crop self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -48,30 +47,36 @@ class CSVDataModule(pl.LightningDataModule): def prepare_data(self): metadata = pd.read_csv(self.data_file) - image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] + instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] + class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] prompts = metadata['prompt'].values - nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) - skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) - self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] + nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) + skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) + self.data = [(i, c, p, n) + for i, c, p, n, s + in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) + if s != "x"] def setup(self, stage=None): - valid_set_size = int(len(self.data_full) * 0.2) + valid_set_size = int(len(self.data) * 0.2) if self.valid_set_size: valid_set_size = min(valid_set_size, self.valid_set_size) - train_set_size = len(self.data_full) - valid_set_size - - self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) - - train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, - class_data_root=self.class_data_root, class_prompt=self.class_prompt, - size=self.size, interpolation=self.interpolation, identifier=self.identifier, - center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size) - val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, - size=self.size, interpolation=self.interpolation, identifier=self.identifier, - center_crop=self.center_crop, batch_size=self.batch_size) - self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, + valid_set_size = max(valid_set_size, 1) + train_set_size = len(self.data) - valid_set_size + + self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) + + train_dataset = CSVDataset(self.data_train, self.tokenizer, + instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, + size=self.size, interpolation=self.interpolation, + center_crop=self.center_crop, repeats=self.repeats) + val_dataset = CSVDataset(self.data_val, self.tokenizer, + instance_identifier=self.instance_identifier, + size=self.size, interpolation=self.interpolation, + center_crop=self.center_crop, repeats=self.repeats) + self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, shuffle=True, pin_memory=True, collate_fn=self.collate_fn) - self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, + self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, pin_memory=True, collate_fn=self.collate_fn) def train_dataloader(self): @@ -85,39 +90,23 @@ class CSVDataset(Dataset): def __init__(self, data, tokenizer, - instance_prompt, - class_data_root=None, - class_prompt=None, + instance_identifier, + class_identifier=None, size=512, repeats=1, interpolation="bicubic", - identifier="*", center_crop=False, - batch_size=1, ): self.data = data self.tokenizer = tokenizer - self.instance_prompt = instance_prompt - self.identifier = identifier - self.batch_size = batch_size + self.instance_identifier = instance_identifier + self.class_identifier = class_identifier self.cache = {} self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats - if class_data_root is not None: - self.class_data_root = Path(class_data_root) - self.class_data_root.mkdir(parents=True, exist_ok=True) - - self.class_images = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images) - self._length = max(self.num_class_images, self.num_instance_images) - - self.class_prompt = class_prompt - else: - self.class_data_root = None - self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, "bilinear": transforms.InterpolationMode.BILINEAR, "bicubic": transforms.InterpolationMode.BICUBIC, @@ -134,46 +123,49 @@ class CSVDataset(Dataset): ) def __len__(self): - return math.ceil(self._length / self.batch_size) * self.batch_size + return self._length def get_example(self, i): - image_path, prompt, nprompt = self.data[i % self.num_instance_images] + instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] - if image_path in self.cache: - return self.cache[image_path] + if instance_image_path in self.cache: + return self.cache[instance_image_path] example = {} - instance_image = Image.open(image_path) + example["prompts"] = prompt + example["nprompts"] = nprompt + + instance_image = Image.open(instance_image_path) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") - prompt = prompt.format(self.identifier) + instance_prompt = prompt.format(self.instance_identifier) - example["prompts"] = prompt - example["nprompts"] = nprompt example["instance_images"] = instance_image example["instance_prompt_ids"] = self.tokenizer( - self.instance_prompt, + instance_prompt, padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids - if self.class_data_root: - class_image = Image.open(self.class_images[i % self.num_class_images]) + if self.class_identifier: + class_image = Image.open(class_image_path) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") + class_prompt = prompt.format(self.class_identifier) + example["class_images"] = class_image example["class_prompt_ids"] = self.tokenizer( - self.class_prompt, + class_prompt, padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, ).input_ids - self.cache[image_path] = example + self.cache[instance_image_path] = example return example def __getitem__(self, i): @@ -185,7 +177,7 @@ class CSVDataset(Dataset): example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] - if self.class_data_root: + if self.class_identifier: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py index 34f510d..b3a83ce 100644 --- a/data/dreambooth/prompt.py +++ b/data/dreambooth/prompt.py @@ -2,8 +2,9 @@ from torch.utils.data import Dataset class PromptDataset(Dataset): - def __init__(self, prompt, num_samples): + def __init__(self, prompt, nprompt, num_samples): self.prompt = prompt + self.nprompt = nprompt self.num_samples = num_samples def __len__(self): @@ -12,5 +13,6 @@ class PromptDataset(Dataset): def __getitem__(self, index): example = {} example["prompt"] = self.prompt + example["nprompt"] = self.nprompt example["index"] = index return example -- cgit v1.2.3-70-g09d2