From 14beba63391e1ddc9a145bb638d9306086ad1a5c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Oct 2022 08:34:07 +0200
Subject: Training: Create multiple class images per training image

---
 dreambooth.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 0e69d79..24e6091 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -65,6 +65,12 @@ def parse_args():
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
+    )
     parser.add_argument(
         "--repeats",
         type=int,
@@ -347,6 +353,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -494,7 +501,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]
 
         # concat class and instance examples for prior preservation
-        if args.class_identifier is not None and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
 
@@ -518,6 +525,7 @@ def main():
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="db_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
@@ -528,7 +536,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    if args.class_identifier is not None:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]
 
         if len(missing_data) != 0:
@@ -547,6 +555,7 @@ def main():
                 scheduler=scheduler,
             ).to(accelerator.device)
             pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
             for batch in batched_data:
                 image_name = [p[1] for p in batch]
@@ -645,11 +654,18 @@ def main():
             0,
             args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")
 
-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")
 
     try:
@@ -686,7 +702,7 @@ def main():
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                    if args.class_identifier is not None:
+                    if args.num_class_images != 0:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                         noise, noise_prior = torch.chunk(noise, 2, dim=0)
--
cgit v1.2.3-54-g00ecf