From 633d890e4964e070be9b0a5b299c2f2e51d4b055 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 17 Oct 2022 12:27:53 +0200
Subject: Upstream updates; better handling of textual embedding

---
 dreambooth_plus.py | 69 ++++++++++++++++++++++++++++++++----------------------
 1 file changed, 41 insertions(+), 28 deletions(-)

(limited to 'dreambooth_plus.py')

diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 73225de..a98417f 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -56,9 +56,9 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--instance_identifier",
         type=str,
-        default="<*>",
+        default=None,
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
@@ -67,6 +67,12 @@ def parse_args():
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
+    parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
     parser.add_argument(
         "--initializer_token",
         type=str,
@@ -118,7 +124,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1200,
+        default=1500,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -135,13 +141,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=5e-5,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=1e-6,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -353,6 +359,7 @@ class Checkpointer:
         ema_unet,
         tokenizer,
         text_encoder,
+        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -368,6 +375,7 @@ class Checkpointer:
         self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -461,7 +469,7 @@ class Checkpointer:
                 for i in range(self.sample_batches):
                     batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                    prompt = [prompt.format(self.placeholder_token)
+                    prompt = [prompt.format(self.instance_identifier)
                               for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
                     nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

@@ -476,7 +484,7 @@ class Checkpointer:
                         eta=eta,
                         num_inference_steps=num_inference_steps,
                         output_type='pil'
-                    )["sample"]
+                    ).images

                     all_samples += samples

@@ -522,28 +530,26 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)

+    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-            " `placeholder_token` that is not already in the tokenizer."
-        )
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # Check if initializer_token is a single token or a sequence of tokens
-    if len(initializer_token_ids) > 1:
-        raise ValueError(
-            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+        print(f"Re-using existing token {args.placeholder_token}.")
+    else:
+        print(f"Training new token {args.placeholder_token}.")

-    initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
@@ -630,6 +636,12 @@ def main():
             num_train_timesteps=args.noise_timesteps
         )

+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -642,7 +654,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -658,7 +670,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_identifier=args.placeholder_token,
+        instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -744,7 +756,7 @@ def main():
     )

     # Move vae and unet to device
-    vae.to(accelerator.device)
+    vae.to(accelerator.device, dtype=weight_dtype)

     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -785,6 +797,7 @@ def main():
         ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -830,7 +843,7 @@ def main():
                 latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -853,15 +866,15 @@ def main():
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)

                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                 accelerator.backward(loss)

@@ -933,7 +946,7 @@ def main():
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215

-                    noise = torch.randn(latents.shape).to(latents.device)
+                    noise = torch.randn_like(latents)
                    bsz = latents.shape[0]
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                              (bsz,), device=latents.device)
@@ -947,7 +960,7 @@ def main():

                     noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                     loss = loss.detach().item()
                     val_loss += loss
--
cgit v1.2.3-54-g00ecf
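Note on the loss hunks above: the patch upcasts predictions and targets with .float() so the MSE is always computed in full precision even when the model runs under fp16/bf16 mixed precision, and torch.randn_like(latents) samples noise directly with the latents' device and dtype instead of allocating on the default device and copying. A minimal self-contained sketch of the resulting prior-preservation loss path follows; the helper name dreambooth_loss is illustrative and not part of the patch, which inlines this logic in the training loop.

import torch
import torch.nn.functional as F

def dreambooth_loss(noise_pred, noise, prior_loss_weight=1.0, with_prior=True):
    # Hypothetical helper mirroring the patched training-loop logic.
    # Upcast to float32 so the MSE is computed in full precision.
    noise_pred, noise = noise_pred.float(), noise.float()
    if with_prior:
        # The batch is [instance images | class images] along dim 0;
        # split predictions and targets into the two halves.
        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
        noise, noise_prior = torch.chunk(noise, 2, dim=0)
        # Instance loss: per-sample mean over C/H/W, then mean over the batch.
        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
        # Prior loss: plain mean over all elements, as in the patched line.
        prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="mean")
        return loss + prior_loss_weight * prior_loss
    return F.mse_loss(noise_pred, noise, reduction="mean")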