From c90099f06e0b461660b326fb6d86b69d86e78289 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 3 Oct 2022 14:47:01 +0200
Subject: Added negative prompt support for training scripts

---
 data/dreambooth/csv.py        | 15 ++++----
 data/textual_inversion/csv.py | 17 +++++----
 dreambooth.py                 | 81 +++++++++++------------------------------
 textual_inversion.py          | 83 ++++++++++++-------------------------------
 4 files changed, 63 insertions(+), 133 deletions(-)

diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 08ed49c..71aa1eb 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule):
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
-        captions = [caption for caption in metadata['caption'].values]
-        skips = [skip for skip in metadata['skip'].values]
-        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+        prompts = metadata['prompt'].values
+        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
+        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
+        self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data_full) * 0.2)
@@ -135,7 +136,7 @@ class CSVDataset(Dataset):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        image_path, text = self.data[i % self.num_instance_images]
+        image_path, prompt, nprompt = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
@@ -146,9 +147,10 @@ class CSVDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
 
-        text = text.format(self.identifier)
+        prompt = prompt.format(self.identifier)
 
-        example["prompts"] = text
+        example["prompts"] = prompt
+        example["nprompts"] = nprompt
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
             self.instance_prompt,
@@ -178,6 +180,7 @@ class CSVDataset(Dataset):
         unprocessed_example = self.get_example(i)
 
         example["prompts"] = unprocessed_example["prompts"]
+        example["nprompts"] = unprocessed_example["nprompts"]
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
         example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
 
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 3ac57df..64f0c28 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule):
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
-        captions = [caption for caption in metadata['caption'].values]
-        skips = [skip for skip in metadata['skip'].values]
-        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+        prompts = metadata['prompt'].values
+        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
+        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
+        self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data_full) * 0.2)
@@ -109,7 +110,7 @@ class CSVDataset(Dataset):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        image_path, text = self.data[i % self.num_instance_images]
+        image_path, prompt, nprompt = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
@@ -120,12 +121,13 @@ class CSVDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
 
-        text = text.format(self.placeholder_token)
+        prompt = prompt.format(self.placeholder_token)
 
-        example["prompts"] = text
+        example["prompts"] = prompt
+        example["nprompts"] = nprompt
         example["pixel_values"] = instance_image
         example["input_ids"] = self.tokenizer(
-            text,
+            prompt,
             padding="max_length",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
@@ -140,6 +142,7 @@ class CSVDataset(Dataset):
         unprocessed_example = self.get_example(i)
 
         example["prompts"] = unprocessed_example["prompts"]
+        example["nprompts"] = unprocessed_example["nprompts"]
         example["input_ids"] = unprocessed_example["input_ids"]
         example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
 
diff --git a/dreambooth.py b/dreambooth.py
index 75602dc..5fbf172 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -191,16 +191,10 @@ def parse_args():
         help="Size of sample images",
     )
     parser.add_argument(
-        "--stable_sample_batches",
+        "--sample_batches",
         type=int,
         default=1,
-        help="Number of fixed seed sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--random_sample_batches",
-        type=int,
-        default=1,
-        help="Number of random seed sample batches to generate per checkpoint",
+        help="Number of sample batches to generate per checkpoint",
     )
     parser.add_argument(
         "--sample_batch_size",
@@ -331,9 +325,8 @@ class Checkpointer:
         text_encoder,
         output_dir,
         sample_image_size,
-        random_sample_batches,
+        sample_batches,
         sample_batch_size,
-        stable_sample_batches,
         seed
     ):
         self.datamodule = datamodule
@@ -345,9 +338,8 @@ class Checkpointer:
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed
-        self.random_sample_batches = random_sample_batches
+        self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
-        self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
     def checkpoint(self):
@@ -396,63 +388,33 @@ class Checkpointer:
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
-        if self.stable_sample_batches > 0:
-            stable_latents = torch.randn(
-                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-                device=pipeline.device,
-                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
-            )
-
-            all_samples = []
-            file_path = samples_path.joinpath("stable", f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            data_enum = enumerate(val_data)
-
-            # Generate and save stable samples
-            for i in range(0, self.stable_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
-
-                with self.accelerator.autocast():
-                    samples = pipeline(
-                        prompt=prompt,
-                        height=self.sample_image_size,
-                        latents=stable_latents[:len(prompt)],
-                        width=self.sample_image_size,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    )["sample"]
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
-
-            del all_samples
-            del image_grid
-            del stable_latents
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
 
-        for data, pool in [(val_data, "val"), (train_data, "train")]:
+        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.png")
             file_path.parent.mkdir(parents=True, exist_ok=True)
 
             data_enum = enumerate(data)
 
-            for i in range(0, self.random_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(self.sample_batches):
+                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
+                        negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
+                        latents=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
@@ -463,7 +425,7 @@ class Checkpointer:
 
                 del samples
 
-            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
             image_grid.save(file_path)
 
             del all_samples
@@ -630,7 +592,7 @@ def main():
         identifier=args.identifier,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.stable_sample_batches,
+        valid_set_size=args.sample_batch_size*args.sample_batches,
         collate_fn=collate_fn)
 
     datamodule.prepare_data()
@@ -649,8 +611,7 @@ def main():
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
-        random_sample_batches=args.random_sample_batches,
-        stable_sample_batches=args.stable_sample_batches,
+        sample_batches=args.sample_batches,
         seed=args.seed or torch.random.seed()
     )
 
diff --git a/textual_inversion.py b/textual_inversion.py
index 285aa0a..00d460f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -207,16 +207,10 @@ def parse_args():
         help="Size of sample images",
     )
     parser.add_argument(
-        "--stable_sample_batches",
+        "--sample_batches",
         type=int,
         default=1,
-        help="Number of fixed seed sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--random_sample_batches",
-        type=int,
-        default=1,
-        help="Number of random seed sample batches to generate per checkpoint",
+        help="Number of sample batches to generate per checkpoint",
    )
     parser.add_argument(
         "--sample_batch_size",
@@ -319,9 +313,8 @@ class Checkpointer:
         placeholder_token_id,
         output_dir,
         sample_image_size,
-        random_sample_batches,
+        sample_batches,
         sample_batch_size,
-        stable_sample_batches,
         seed
     ):
         self.datamodule = datamodule
@@ -334,9 +327,8 @@ class Checkpointer:
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed
-        self.random_sample_batches = random_sample_batches
+        self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
-        self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
     def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
@@ -385,63 +377,33 @@ class Checkpointer:
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
-        if self.stable_sample_batches > 0:
-            stable_latents = torch.randn(
-                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-                device=pipeline.device,
-                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
-            )
-
-            all_samples = []
-            file_path = samples_path.joinpath("stable", f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            data_enum = enumerate(val_data)
-
-            # Generate and save stable samples
-            for i in range(0, self.stable_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
-
-                with self.accelerator.autocast():
-                    samples = pipeline(
-                        prompt=prompt,
-                        height=self.sample_image_size,
-                        latents=stable_latents[:len(prompt)],
-                        width=self.sample_image_size,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    )["sample"]
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
-
-            del all_samples
-            del image_grid
-            del stable_latents
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
 
-        for data, pool in [(val_data, "val"), (train_data, "train")]:
+        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.png")
             file_path.parent.mkdir(parents=True, exist_ok=True)
 
             data_enum = enumerate(data)
 
-            for i in range(0, self.random_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(self.sample_batches):
+                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
+                        negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
+                        latents=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
@@ -452,7 +414,7 @@ class Checkpointer:
 
                 del samples
 
-            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
             image_grid.save(file_path)
 
             del all_samples
@@ -461,6 +423,8 @@ class Checkpointer:
         del unwrapped
         del scheduler
         del pipeline
+        del generator
+        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -603,7 +567,7 @@ def main():
         placeholder_token=args.placeholder_token,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.stable_sample_batches
+        valid_set_size=args.sample_batch_size*args.sample_batches
     )
 
     datamodule.prepare_data()
@@ -623,8 +587,7 @@ def main():
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
-        random_sample_batches=args.random_sample_batches,
-        stable_sample_batches=args.stable_sample_batches,
+        sample_batches=args.sample_batches,
         seed=args.seed or torch.random.seed()
     )
-- 
cgit v1.2.3-70-g09d2
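
For readers adapting this change: the training CSV that prepare_data() now consumes replaces the old 'caption' column with 'prompt' and adds optional 'nprompt' (negative prompt) and 'skip' columns. Below is a minimal, self-contained sketch of that loading logic, with stated assumptions: the file name and row contents are invented, and keep_default_na=False is an addition so empty cells read back as "" rather than NaN (the patch itself calls pd.read_csv with defaults).

import pandas as pd

# Hypothetical metadata.csv; 'nprompt' and 'skip' are optional, and "{}" is
# later filled in by get_example() via str.format() with the DreamBooth
# identifier or the textual inversion placeholder token.
csv_text = """image,prompt,nprompt,skip
img1.jpg,a photo of {},"blurry, low quality",
img2.jpg,a painting of {},,x
"""

with open("metadata.csv", "w") as f:
    f.write(csv_text)

# keep_default_na=False is an addition for this sketch so empty cells come
# back as "" instead of NaN; the patch calls pd.read_csv with defaults.
metadata = pd.read_csv("metadata.csv", keep_default_na=False)

image_paths = metadata['image'].values
prompts = metadata['prompt'].values
nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)

# Rows whose skip column is "x" are dropped, as in prepare_data().
data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
print(data_full)  # [('img1.jpg', 'a photo of {}', 'blurry, low quality')]

Because both the 'nprompt' and 'skip' lookups fall back to empty strings, CSVs written for the old two-column format keep working once 'caption' is renamed to 'prompt'.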
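The reworked Checkpointer.checkpoint() merges the former fixed-seed ("stable") and random sampling paths into a single loop: the stable pool passes the same seeded latents at every checkpoint so its sample grids stay comparable across training, while the val and train pools pass latents=None and sample fresh noise. The sketch below mirrors that control flow with the diffusers pipeline stubbed out; the Loader class, the batch contents, and the 4x64x64 latent shape are illustrative assumptions, not code from the patch.

import torch

sample_batch_size = 2
sample_batches = 1

class Loader:
    # Stands in for the datamodule's train/val dataloaders; yields collated
    # batches shaped like the patch's, including the new "nprompts" key.
    batch_size = 2

    def __iter__(self):
        yield {"prompts": ["a photo of ti", "a photo of ti"],
               "nprompts": ["blurry", "low quality"]}

val_data = Loader()

generator = torch.Generator().manual_seed(42)
# 4 latent channels and 512 // 8 = 64 latent positions match a typical SD v1
# UNet; the patch reads these from pipeline.unet.in_channels and the sizes.
stable_latents = torch.randn((sample_batch_size, 4, 64, 64), generator=generator)

for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None)]:
    for i in range(sample_batches):
        # Deviation from the patch: the enumerator is rebuilt on each pass. A
        # list comprehension drains the iterator completely, so the patch's
        # once-per-pool enumerate(data) leaves later passes with no batches.
        data_enum = enumerate(data)
        batches = [batch for j, batch in data_enum if j * data.batch_size < sample_batch_size]
        prompt = [p for batch in batches for p in batch["prompts"]][:sample_batch_size]
        nprompt = [p for batch in batches for p in batch["nprompts"]][:sample_batch_size]

        # The real loop hands these to the diffusers pipeline:
        #   pipeline(prompt=prompt, negative_prompt=nprompt,
        #            latents=latents[:len(prompt)] if latents is not None else None,
        #            generator=generator if latents is not None else None, ...)
        print(pool, prompt, nprompt,
              None if latents is None else tuple(latents[:len(prompt)].shape))

Note how the same seeded generator is reused for the pipeline call whenever fixed latents are passed, which keeps any remaining stochastic steps in the stable pool reproducible from one checkpoint to the next.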