From 633d890e4964e070be9b0a5b299c2f2e51d4b055 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 17 Oct 2022 12:27:53 +0200 Subject: Upstream updates; better handling of textual embedding --- dreambooth.py | 26 ++++---- dreambooth_plus.py | 69 +++++++++++++--------- .../stable_diffusion/vlpn_stable_diffusion.py | 2 +- schedulers/scheduling_euler_a.py | 3 - textual_inversion.py | 41 +++++++------ 5 files changed, 82 insertions(+), 59 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index 42d3980..770ad38 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -430,7 +430,7 @@ class Checkpointer: eta=eta, num_inference_steps=num_inference_steps, output_type='pil' - )["sample"] + ).images all_samples += samples @@ -537,6 +537,12 @@ def main(): num_train_timesteps=args.noise_timesteps ) + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + def collate_fn(examples): prompts = [example["prompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] @@ -549,7 +555,7 @@ def main(): pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format) + pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids @@ -651,8 +657,8 @@ def main(): ) # Move text_encoder and vae to device - text_encoder.to(accelerator.device) - vae.to(accelerator.device) + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) # Keep text_encoder and vae in eval mode as we don't train these text_encoder.eval() @@ -738,7 +744,7 @@ def main(): latents = latents * 0.18215 # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) + noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, @@ -761,15 +767,15 @@ def main(): noise, noise_prior = torch.chunk(noise, 2, dim=0) # Compute instance loss - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() + prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: @@ -818,7 +824,7 @@ def main(): latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 - noise = torch.randn(latents.shape).to(latents.device) + noise = torch.randn_like(latents) bsz = latents.shape[0] timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -832,7 +838,7 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") loss = loss.detach().item() val_loss += loss diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 73225de..a98417f 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py @@ -56,9 +56,9 @@ def parse_args(): help="A CSV file containing the training data." ) parser.add_argument( - "--placeholder_token", + "--instance_identifier", type=str, - default="<*>", + default=None, help="A token to use as a placeholder for the concept.", ) parser.add_argument( @@ -67,6 +67,12 @@ def parse_args(): default=None, help="A token to use as a placeholder for the concept.", ) + parser.add_argument( + "--placeholder_token", + type=str, + default="<*>", + help="A token to use as a placeholder for the concept.", + ) parser.add_argument( "--initializer_token", type=str, @@ -118,7 +124,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=1200, + default=1500, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -135,13 +141,13 @@ def parse_args(): parser.add_argument( "--learning_rate_unet", type=float, - default=5e-5, + default=5e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--learning_rate_text", type=float, - default=1e-6, + default=5e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -353,6 +359,7 @@ class Checkpointer: ema_unet, tokenizer, text_encoder, + instance_identifier, placeholder_token, placeholder_token_id, output_dir: Path, @@ -368,6 +375,7 @@ class Checkpointer: self.ema_unet = ema_unet self.tokenizer = tokenizer self.text_encoder = text_encoder + self.instance_identifier = instance_identifier self.placeholder_token = placeholder_token self.placeholder_token_id = placeholder_token_id self.output_dir = output_dir @@ -461,7 +469,7 @@ class Checkpointer: for i in range(self.sample_batches): batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] - prompt = [prompt.format(self.placeholder_token) + prompt = [prompt.format(self.instance_identifier) for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] @@ -476,7 +484,7 @@ class Checkpointer: eta=eta, num_inference_steps=num_inference_steps, output_type='pil' - )["sample"] + ).images all_samples += samples @@ -522,28 +530,26 @@ def main(): if args.seed is not None: set_seed(args.seed) + args.instance_identifier = args.instance_identifier.format(args.placeholder_token) + # Load the tokenizer and add the placeholder token as a additional special token if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') + # Convert the initializer_token, placeholder_token to ids + initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.") + initializer_token_ids = torch.tensor(initializer_token_ids[:1]) + # Add the placeholder token in tokenizer num_added_tokens = tokenizer.add_tokens(args.placeholder_token) if num_added_tokens == 0: - raise ValueError( - f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" - " `placeholder_token` that is not already in the tokenizer." - ) - - # Convert the initializer_token, placeholder_token to ids - initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) - # Check if initializer_token is a single token or a sequence of tokens - if len(initializer_token_ids) > 1: - raise ValueError( - f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.") + print(f"Re-using existing token {args.placeholder_token}.") + else: + print(f"Training new token {args.placeholder_token}.") - initializer_token_ids = torch.tensor(initializer_token_ids) placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) # Load models and create wrapper for stable diffusion @@ -630,6 +636,12 @@ def main(): num_train_timesteps=args.noise_timesteps ) + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + def collate_fn(examples): prompts = [example["prompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] @@ -642,7 +654,7 @@ def main(): pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) + pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids @@ -658,7 +670,7 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, - instance_identifier=args.placeholder_token, + instance_identifier=args.instance_identifier, class_identifier=args.class_identifier, class_subdir="cls", num_class_images=args.num_class_images, @@ -744,7 +756,7 @@ def main(): ) # Move vae and unet to device - vae.to(accelerator.device) + vae.to(accelerator.device, dtype=weight_dtype) # Keep vae and unet in eval mode as we don't train these vae.eval() @@ -785,6 +797,7 @@ def main(): ema_unet=ema_unet, tokenizer=tokenizer, text_encoder=text_encoder, + instance_identifier=args.instance_identifier, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, output_dir=basepath, @@ -830,7 +843,7 @@ def main(): latents = latents * 0.18215 # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) + noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, @@ -853,15 +866,15 @@ def main(): noise, noise_prior = torch.chunk(noise, 2, dim=0) # Compute instance loss - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() + prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") accelerator.backward(loss) @@ -933,7 +946,7 @@ def main(): latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 - noise = torch.randn(latents.shape).to(latents.device) + noise = torch.randn_like(latents) bsz = latents.shape[0] timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -947,7 +960,7 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") loss = loss.detach().item() val_loss += loss diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 2656b28..8b08a6f 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -301,7 +301,7 @@ class VlpnStableDiffusion(DiffusionPipeline): # scale and decode the image latents with vae latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 6abe971..c097a8a 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -47,7 +47,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", num_inference_steps=None, device='cuda' ): @@ -63,7 +62,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.device = device - self.tensor_format = tensor_format self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -77,7 +75,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): # get sigmas self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) - self.set_format(tensor_format=tensor_format) # A# take number of steps as input # A# store 1) number of steps 2) timesteps 3) schedule diff --git a/textual_inversion.py b/textual_inversion.py index 0d5a742..69d9c7f 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -55,9 +55,9 @@ def parse_args(): help="A CSV file containing the training data." ) parser.add_argument( - "--placeholder_token", + "--instance_identifier", type=str, - default="<*>", + default=None, help="A token to use as a placeholder for the concept.", ) parser.add_argument( @@ -66,6 +66,12 @@ def parse_args(): default=None, help="A token to use as a placeholder for the concept.", ) + parser.add_argument( + "--placeholder_token", + type=str, + default="<*>", + help="A token to use as a placeholder for the concept.", + ) parser.add_argument( "--initializer_token", type=str, @@ -333,6 +339,7 @@ class Checkpointer: unet, tokenizer, text_encoder, + instance_identifier, placeholder_token, placeholder_token_id, output_dir: Path, @@ -347,6 +354,7 @@ class Checkpointer: self.unet = unet self.tokenizer = tokenizer self.text_encoder = text_encoder + self.instance_identifier = instance_identifier self.placeholder_token = placeholder_token self.placeholder_token_id = placeholder_token_id self.output_dir = output_dir @@ -413,7 +421,7 @@ class Checkpointer: for i in range(self.sample_batches): batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] - prompt = [prompt.format(self.placeholder_token) + prompt = [prompt.format(self.instance_identifier) for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] @@ -428,7 +436,7 @@ class Checkpointer: eta=eta, num_inference_steps=num_inference_steps, output_type='pil' - )["sample"] + ).images all_samples += samples @@ -480,28 +488,26 @@ def main(): if args.seed is not None: set_seed(args.seed) + args.instance_identifier = args.instance_identifier.format(args.placeholder_token) + # Load the tokenizer and add the placeholder token as a additional special token if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') + # Convert the initializer_token, placeholder_token to ids + initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.") + initializer_token_ids = torch.tensor(initializer_token_ids[:1]) + # Add the placeholder token in tokenizer num_added_tokens = tokenizer.add_tokens(args.placeholder_token) if num_added_tokens == 0: - raise ValueError( - f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" - " `placeholder_token` that is not already in the tokenizer." - ) - - # Convert the initializer_token, placeholder_token to ids - initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) - # Check if initializer_token is a single token or a sequence of tokens - if len(initializer_token_ids) > 1: - raise ValueError( - f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.") + print(f"Re-using existing token {args.placeholder_token}.") + else: + print(f"Training new token {args.placeholder_token}.") - initializer_token_ids = torch.tensor(initializer_token_ids) placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) # Load models and create wrapper for stable diffusion @@ -602,7 +608,7 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, - instance_identifier=args.placeholder_token, + instance_identifier=args.instance_identifier, class_identifier=args.class_identifier, class_subdir="cls", num_class_images=args.num_class_images, @@ -730,6 +736,7 @@ def main(): unet=unet, tokenizer=tokenizer, text_encoder=text_encoder, + instance_identifier=args.instance_identifier, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, output_dir=basepath, -- cgit v1.2.3-70-g09d2