From 14beba63391e1ddc9a145bb638d9306086ad1a5c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Oct 2022 08:34:07 +0200
Subject: Training: Create multiple class images per training image

---
 textual_inversion.py | 48 +++++++++++++++++++++++++++++++++++-------------
 1 file changed, 35 insertions(+), 13 deletions(-)

diff --git a/textual_inversion.py b/textual_inversion.py
index 11c324d..86fcdfe 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -68,16 +68,17 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--use_class_images",
-        action="store_true",
-        default=True,
-        help="Include class images in the loss calculation a la Dreambooth.",
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
     )
     parser.add_argument(
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -203,6 +204,12 @@ def parse_args():
         default=500,
         help="How often to save a checkpoint and sample image",
     )
+    parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=100,
+        help="How often to generate sample images",
+    )
     parser.add_argument(
         "--sample_image_size",
         type=int,
@@ -381,6 +388,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -577,7 +585,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]
 
         # concat class and instance examples for prior preservation
-        if args.use_class_images and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
 
@@ -599,8 +607,9 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token if args.use_class_images else None,
+        class_identifier=args.initializer_token,
         class_subdir="ti_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
@@ -611,7 +620,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    if args.use_class_images:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]
 
         if len(missing_data) != 0:
@@ -630,6 +639,7 @@ def main():
             scheduler=scheduler,
         ).to(accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         for batch in batched_data:
             image_name = [p[1] for p in batch]
@@ -729,11 +739,18 @@ def main():
             text_encoder,
             args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")
 
-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")
 
     try:
@@ -744,6 +761,8 @@ def main():
             text_encoder.train()
             train_loss = 0.0
 
+            sample_checkpoint = False
+
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
@@ -769,7 +788,7 @@ def main():
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                    if args.use_class_images:
+                    if args.num_class_images != 0:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                         noise, noise_prior = torch.chunk(noise, 2, dim=0)
@@ -812,6 +831,9 @@ def main():
 
                     global_step += 1
 
+                    if global_step % args.sample_frequency == 0:
+                        sample_checkpoint = True
+
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
@@ -878,7 +900,7 @@ def main():
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                 min_val_loss = val_loss
 
-            if accelerator.is_main_process:
+            if sample_checkpoint and accelerator.is_main_process:
                 checkpointer.save_samples(
                     global_step + global_step_offset,
                     text_encoder,
--
cgit v1.2.3-54-g00ecf
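
For context on the collate change above: when --num_class_images is nonzero, each batch holds the instance examples followed by their generated class examples, keeping the two halves aligned and equal in size for DreamBooth-style prior preservation. A minimal sketch of that step, assuming each example dict carries the keys referenced in the diff ("instance_prompt_ids" is an assumed key the hunk does not show, and tokenizer is a Hugging Face tokenizer):

import torch

def collate(examples, tokenizer, num_class_images=2):
    # "instance_prompt_ids" is an assumed key; the diff only shows
    # "instance_images" and the class-side keys.
    input_ids = [e["instance_prompt_ids"] for e in examples]
    pixel_values = [e["instance_images"] for e in examples]

    # Concat class and instance examples for prior preservation, so the
    # loss can later chunk the batch back into an instance half and a
    # class half of equal size.
    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
        input_ids += [e["class_prompt_ids"] for e in examples]
        pixel_values += [e["class_images"] for e in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = tokenizer.pad(
        {"input_ids": input_ids}, padding=True, return_tensors="pt"
    ).input_ids

    return {"input_ids": input_ids, "pixel_values": pixel_values}

Doubling the batch this way lets one forward pass cover both halves, at the cost of roughly twice the memory per step.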
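The training loop then splits the model output back into those two halves, as the torch.chunk calls above show. A minimal sketch of the combined loss under that batch layout; prior_loss_weight is a hypothetical weighting knob for illustration, not a flag added by this patch:

import torch
import torch.nn.functional as F

def prior_preservation_loss(noise_pred, noise, prior_loss_weight=1.0):
    # Chunk the noise and noise_pred into two parts and compute the loss
    # on each part separately (first half: instance, second half: class).
    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(noise, 2, dim=0)

    # Instance loss: learn the new concept from the training images.
    loss = F.mse_loss(noise_pred, noise)

    # Prior loss: keep the class prior intact, a la DreamBooth.
    prior_loss = F.mse_loss(noise_pred_prior, noise_prior)

    # prior_loss_weight is an assumed hyperparameter, not part of this patch.
    return loss + prior_loss_weight * prior_loss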