From c90099f06e0b461660b326fb6d86b69d86e78289 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 14:47:01 +0200 Subject: Added negative prompt support for training scripts --- textual_inversion.py | 83 +++++++++++++++------------------------------------- 1 file changed, 23 insertions(+), 60 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 285aa0a..00d460f 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -207,16 +207,10 @@ def parse_args(): help="Size of sample images", ) parser.add_argument( - "--stable_sample_batches", + "--sample_batches", type=int, default=1, - help="Number of fixed seed sample batches to generate per checkpoint", - ) - parser.add_argument( - "--random_sample_batches", - type=int, - default=1, - help="Number of random seed sample batches to generate per checkpoint", + help="Number of sample batches to generate per checkpoint", ) parser.add_argument( "--sample_batch_size", @@ -319,9 +313,8 @@ class Checkpointer: placeholder_token_id, output_dir, sample_image_size, - random_sample_batches, + sample_batches, sample_batch_size, - stable_sample_batches, seed ): self.datamodule = datamodule @@ -334,9 +327,8 @@ class Checkpointer: self.output_dir = output_dir self.sample_image_size = sample_image_size self.seed = seed - self.random_sample_batches = random_sample_batches + self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size - self.stable_sample_batches = stable_sample_batches @torch.no_grad() def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): @@ -385,63 +377,33 @@ class Checkpointer: train_data = self.datamodule.train_dataloader() val_data = self.datamodule.val_dataloader() - if self.stable_sample_batches > 0: - stable_latents = torch.randn( - (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), - device=pipeline.device, - generator=torch.Generator(device=pipeline.device).manual_seed(self.seed), - ) - - all_samples = [] - file_path = samples_path.joinpath("stable", f"step_{step}.png") - file_path.parent.mkdir(parents=True, exist_ok=True) - - data_enum = enumerate(val_data) - - # Generate and save stable samples - for i in range(0, self.stable_sample_batches): - prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( - batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] - - with self.accelerator.autocast(): - samples = pipeline( - prompt=prompt, - height=self.sample_image_size, - latents=stable_latents[:len(prompt)], - width=self.sample_image_size, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - output_type='pil' - )["sample"] - - all_samples += samples - - del samples - - image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) - image_grid.save(file_path) - - del all_samples - del image_grid - del stable_latents + generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) + stable_latents = torch.randn( + (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), + device=pipeline.device, + generator=generator, + ) - for data, pool in [(val_data, "val"), (train_data, "train")]: + for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: all_samples = [] file_path = samples_path.joinpath(pool, f"step_{step}.png") file_path.parent.mkdir(parents=True, exist_ok=True) data_enum = enumerate(data) - for i in range(0, self.random_sample_batches): - prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( - batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] + for i in range(self.sample_batches): + batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] + prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] + nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] with self.accelerator.autocast(): samples = pipeline( prompt=prompt, + negative_prompt=nprompt, height=self.sample_image_size, width=self.sample_image_size, + latents=latents[:len(prompt)] if latents is not None else None, + generator=generator if latents is not None else None, guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, @@ -452,7 +414,7 @@ class Checkpointer: del samples - image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) + image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) image_grid.save(file_path) del all_samples @@ -461,6 +423,8 @@ class Checkpointer: del unwrapped del scheduler del pipeline + del generator + del stable_latents if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -603,7 +567,7 @@ def main(): placeholder_token=args.placeholder_token, repeats=args.repeats, center_crop=args.center_crop, - valid_set_size=args.sample_batch_size*args.stable_sample_batches + valid_set_size=args.sample_batch_size*args.sample_batches ) datamodule.prepare_data() @@ -623,8 +587,7 @@ def main(): output_dir=basepath, sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, - random_sample_batches=args.random_sample_batches, - stable_sample_batches=args.stable_sample_batches, + sample_batches=args.sample_batches, seed=args.seed or torch.random.seed() ) -- cgit v1.2.3-54-g00ecf