From 14beba63391e1ddc9a145bb638d9306086ad1a5c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Oct 2022 08:34:07 +0200
Subject: Training: Create multiple class images per training image

---
 data/csv.py          | 54 ++++++++++++++++++++++++++++++++++++++--------------
 dreambooth.py        | 28 +++++++++++++++++++++------
 infer.py             |  5 ++++-
 textual_inversion.py | 48 +++++++++++++++++++++++++++++++++-------------
 4 files changed, 101 insertions(+), 34 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index abd329d..dcaf7d3 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,5 +1,3 @@
-import math
-import os
 import pandas as pd
 from pathlib import Path
 import pytorch_lightning as pl
@@ -16,6 +14,7 @@ class CSVDataModule(pl.LightningDataModule):
                  instance_identifier,
                  class_identifier=None,
                  class_subdir="db_cls",
+                 num_class_images=2,
                  size=512,
                  repeats=100,
                  interpolation="bicubic",
@@ -33,6 +32,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_root = self.data_file.parent
         self.class_root = self.data_root.joinpath(class_subdir)
         self.class_root.mkdir(parents=True, exist_ok=True)
+        self.num_class_images = num_class_images
 
         self.tokenizer = tokenizer
         self.instance_identifier = instance_identifier
@@ -48,15 +48,37 @@ class CSVDataModule(pl.LightningDataModule):
 
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
-        instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values]
-        class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values]
-        prompts = metadata['prompt'].values
-        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths)
-        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths)
-        self.data = [(i, c, p, n)
-                     for i, c, p, n, s
-                     in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
-                     if s != "x"]
+        instance_image_paths = [
+            self.data_root.joinpath(f)
+            for f in metadata['image'].values
+            for i in range(self.num_class_images)
+        ]
+        class_image_paths = [
+            self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}")
+            for f in metadata['image'].values
+            for i in range(self.num_class_images)
+        ]
+        prompts = [
+            prompt
+            for prompt in metadata['prompt'].values
+            for i in range(self.num_class_images)
+        ]
+        nprompts = [
+            nprompt
+            for nprompt in metadata['nprompt'].values
+            for i in range(self.num_class_images)
+        ] if 'nprompt' in metadata else [""] * len(instance_image_paths)
+        skips = [
+            skip
+            for skip in metadata['skip'].values
+            for i in range(self.num_class_images)
+        ] if 'skip' in metadata else [""] * len(instance_image_paths)
+        self.data = [
+            (i, c, p, n)
+            for i, c, p, n, s
+            in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
+            if s != "x"
+        ]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data) * 0.2)
@@ -69,6 +91,7 @@ class CSVDataModule(pl.LightningDataModule):
         train_dataset = CSVDataset(self.data_train, self.tokenizer,
                                    instance_identifier=self.instance_identifier,
                                    class_identifier=self.class_identifier,
+                                   num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
                                    center_crop=self.center_crop, repeats=self.repeats)
         val_dataset = CSVDataset(self.data_val, self.tokenizer,
@@ -93,6 +116,7 @@ class CSVDataset(Dataset):
                  tokenizer,
                  instance_identifier,
                  class_identifier=None,
+                 num_class_images=2,
                  size=512,
                  repeats=1,
                  interpolation="bicubic",
@@ -103,6 +127,7 @@ class CSVDataset(Dataset):
         self.tokenizer = tokenizer
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
+        self.num_class_images = num_class_images
         self.cache = {}
 
         self.num_instance_images = len(self.data)
@@ -128,9 +153,10 @@ class CSVDataset(Dataset):
 
     def get_example(self, i):
         instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
+        cache_key = f"{instance_image_path}_{class_image_path}"
 
-        if instance_image_path in self.cache:
-            return self.cache[instance_image_path]
+        if cache_key in self.cache:
+            return self.cache[cache_key]
 
         example = {}
 
@@ -149,7 +175,7 @@ class CSVDataset(Dataset):
             max_length=self.tokenizer.model_max_length,
         ).input_ids
 
-        if self.class_identifier is not None:
+        if self.num_class_images != 0:
             class_image = Image.open(class_image_path)
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
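A note on the data/csv.py change above: each metadata row now fans out into
num_class_images entries, so every instance image is paired with several
distinct class-image paths. The following standalone sketch (not part of the
patch; the filename is hypothetical) reproduces the naming scheme. Since
Path.suffix keeps its leading dot, the generated names come out as
"sample_0_.jpg" rather than "sample_0.jpg":

    from pathlib import Path

    num_class_images = 2      # mirrors the new default
    images = ["sample.jpg"]   # stands in for metadata['image'].values

    # Same nested comprehension as in prepare_data(): outer loop over rows,
    # inner loop over the class-image index.
    class_image_paths = [
        Path("db_cls") / f"{Path(f).stem}_{i}_{Path(f).suffix}"
        for f in images
        for i in range(num_class_images)
    ]
    print(class_image_paths)
    # [PosixPath('db_cls/sample_0_.jpg'), PosixPath('db_cls/sample_1_.jpg')]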
diff --git a/dreambooth.py b/dreambooth.py
index 0e69d79..24e6091 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -65,6 +65,12 @@ def parse_args():
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
+    )
     parser.add_argument(
         "--repeats",
         type=int,
@@ -347,6 +353,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -494,7 +501,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]
 
         # concat class and instance examples for prior preservation
-        if args.class_identifier is not None and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
 
@@ -518,6 +525,7 @@ def main():
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="db_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
@@ -528,7 +536,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    if args.class_identifier is not None:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]
 
         if len(missing_data) != 0:
@@ -547,6 +555,7 @@ def main():
                 scheduler=scheduler,
             ).to(accelerator.device)
             pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
             for batch in batched_data:
                 image_name = [p[1] for p in batch]
@@ -645,11 +654,18 @@ def main():
         0,
         args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")
 
-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")
 
     try:
@@ -686,7 +702,7 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                if args.class_identifier is not None:
+                if args.num_class_images != 0:
                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
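The training-loop change above only swaps the gate from args.class_identifier
to args.num_class_images; the prior-preservation math itself is untouched. As
a reading aid, here is a self-contained sketch of that chunking, assuming the
collate_fn layout from this file (instance examples stacked first, class
examples second; the shapes and the weight of 1.0 on the prior term are
hypothetical):

    import torch
    import torch.nn.functional as F

    # A batch of 2 instance + 2 class predictions stacked along dim 0,
    # as produced by the collate_fn above.
    noise_pred = torch.randn(4, 4, 64, 64)
    noise = torch.randn(4, 4, 64, 64)

    # Split each tensor back into its instance and class halves.
    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(noise, 2, dim=0)

    loss = F.mse_loss(noise_pred, noise)                    # instance loss
    prior_loss = F.mse_loss(noise_pred_prior, noise_prior)  # prior loss
    loss = loss + 1.0 * prior_loss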
diff --git a/infer.py b/infer.py
index 34e570a..6197aa3 100644
--- a/infer.py
+++ b/infer.py
@@ -262,7 +262,10 @@ def generate(output_dir, pipeline, args):
 
     with autocast("cuda"):
         for i in range(args.batch_num):
-            pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
+            pipeline.set_progress_bar_config(
+                desc=f"Batch {i + 1} of {args.batch_num}",
+                dynamic_ncols=True
+            )
 
             generator = torch.Generator(device="cuda").manual_seed(seed + i)
             images = pipeline(
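The dynamic_ncols=True arguments added throughout this commit (here and in
both training scripts) are forwarded to tqdm, which then re-measures the
terminal width on every refresh instead of fixing it once at creation, so a
resized terminal no longer leaves wrapped, duplicated bars behind. A minimal
illustration with tqdm directly:

    import time

    from tqdm.auto import tqdm

    # Resize the terminal while this runs: the bar adapts in place.
    for _ in tqdm(range(20), desc="Batch 1 of 1", dynamic_ncols=True):
        time.sleep(0.05)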
diff --git a/textual_inversion.py b/textual_inversion.py
index 11c324d..86fcdfe 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -68,16 +68,17 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--use_class_images",
-        action="store_true",
-        default=True,
-        help="Include class images in the loss calculation a la Dreambooth.",
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
     )
     parser.add_argument(
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -203,6 +204,12 @@ def parse_args():
         default=500,
         help="How often to save a checkpoint and sample image",
     )
+    parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=100,
+        help="How often to generate sample images",
+    )
     parser.add_argument(
         "--sample_image_size",
         type=int,
@@ -381,6 +388,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -577,7 +585,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]
 
         # concat class and instance examples for prior preservation
-        if args.use_class_images and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
 
@@ -599,8 +607,9 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token if args.use_class_images else None,
+        class_identifier=args.initializer_token,
         class_subdir="ti_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
@@ -611,7 +620,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    if args.use_class_images:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]
 
         if len(missing_data) != 0:
@@ -630,6 +639,7 @@ def main():
                 scheduler=scheduler,
             ).to(accelerator.device)
             pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
             for batch in batched_data:
                 image_name = [p[1] for p in batch]
@@ -729,11 +739,18 @@ def main():
         text_encoder,
         args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")
 
-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")
 
     try:
@@ -744,6 +761,8 @@ def main():
         text_encoder.train()
         train_loss = 0.0
 
+        sample_checkpoint = False
+
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
@@ -769,7 +788,7 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                if args.use_class_images:
+                if args.num_class_images != 0:
                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
@@ -812,6 +831,9 @@ def main():
 
                 global_step += 1
 
+                if global_step % args.sample_frequency == 0:
+                    sample_checkpoint = True
+
                 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
@@ -878,7 +900,7 @@ def main():
             checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
             min_val_loss = val_loss
 
-        if accelerator.is_main_process:
+        if sample_checkpoint and accelerator.is_main_process:
            checkpointer.save_samples(
                 global_step + global_step_offset,
                 text_encoder,
--
cgit v1.2.3-54-g00ecf
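One consequence of the textual_inversion.py changes worth spelling out:
sampling is now decoupled from checkpointing. The new --sample_frequency flag
latches sample_checkpoint whenever global_step crosses a multiple of it, and
the flag is consumed once per epoch when samples are saved. A condensed
sketch of that cadence (step counts are hypothetical):

    sample_frequency = 100  # --sample_frequency
    steps_per_epoch = 250   # hypothetical
    global_step = 0

    for epoch in range(2):
        sample_checkpoint = False
        for _ in range(steps_per_epoch):
            global_step += 1
            if global_step % sample_frequency == 0:
                sample_checkpoint = True
        # The real loop additionally requires accelerator.is_main_process.
        if sample_checkpoint:
            print(f"epoch {epoch}: save samples at step {global_step}")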