From 46b1eda6d1c7db552ce5c577bed101c61f09d55b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 22 Oct 2022 16:56:10 +0200
Subject: Revert lat; fix skip attribute in dataset

---
 data/csv.py   |   2 +-
 dreambooth.py | 161 +++++++++++++++++++++++++++++++---------------------------
 2 files changed, 87 insertions(+), 76 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 4c91ded..df15c5a 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -76,7 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
 
     def prepare_data(self):
         metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True]
+        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
 
         num_images = len(metadata)
         valid_set_size = int(num_images * 0.2)
diff --git a/dreambooth.py b/dreambooth.py
index 1539e81..72c56cd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -810,75 +810,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def run_step(batch, train=False, class_images=False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
-
-        # Predict the noise residual
-        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        if class_images:
-            # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-            noise, noise_prior = torch.chunk(noise, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
-
-        if train:
-            accelerator.backward(loss)
-
-            if args.initializer_token is not None:
-                # Keep the token embeddings fixed except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    token_embeds = text_encoder.module.get_input_embeddings().weight
-                else:
-                    token_embeds = text_encoder.get_input_embeddings().weight
-
-                # Get the index for tokens that we want to freeze
-                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
-                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
-            if accelerator.sync_gradients:
-                params_to_clip = (
-                    unet.parameters()
-                    if args.initializer_token is not None
-                    else itertools.chain(unet.parameters(), text_encoder.parameters())
-                )
-                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-            optimizer.step()
-            if not accelerator.optimizer_step_was_skipped:
-                lr_scheduler.step()
-            optimizer.zero_grad(set_to_none=True)
-
-        loss = loss.detach().item()
-        return loss
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -892,11 +823,72 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(itertools.chain(unet, text_encoder)):
-                    loss = run_step(
-                        batch,
-                        train=True,
-                        class_images=args.num_class_images != 0
-                    )
+                    # Convert images to latent space
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+                    # Predict the noise residual
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    if args.num_class_images != 0:
+                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                        # Compute instance loss
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+
+                        # Compute prior loss
+                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+                        # Add the prior loss to the instance loss.
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                    accelerator.backward(loss)
+
+                    if args.initializer_token is not None:
+                        # Keep the token embeddings fixed except the newly added
+                        # embeddings for the concept, as we only want to optimize the concept embeddings
+                        if accelerator.num_processes > 1:
+                            token_embeds = text_encoder.module.get_input_embeddings().weight
+                        else:
+                            token_embeds = text_encoder.get_input_embeddings().weight
+
+                        # Get the index for tokens that we want to freeze
+                        index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                        token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            unet.parameters()
+                            if args.initializer_token is not None
+                            else itertools.chain(unet.parameters(), text_encoder.parameters())
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+
+                    loss = loss.detach().item()
 
                     train_loss += loss
 
                # Checks if the accelerator has performed an optimization step behind the scenes
@@ -937,7 +929,26 @@ def main():
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    loss = run_step(batch)
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                    loss = loss.detach().item()
 
                     val_loss += loss
 
                 if accelerator.sync_gradients:
-- 
cgit v1.2.3-70-g09d2
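For readers skimming the patch: the heart of the inlined training step is DreamBooth's prior-preservation loss. The batch is assumed to be stacked as [instance images | class images] along dim 0, which is why chunking noise_pred and noise in half separates the two groups. Below is a minimal, self-contained sketch of just that computation; the function name prior_preservation_loss and the random tensors are illustrative stand-ins, with args.prior_loss_weight reduced to a plain parameter.

import torch
import torch.nn.functional as F

def prior_preservation_loss(noise_pred, noise, prior_loss_weight=1.0):
    # Split predictions and targets into the instance half and the class (prior) half.
    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(noise, 2, dim=0)

    # Instance loss: per-sample MSE over (C, H, W), then averaged over the batch.
    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()

    # Prior loss: plain MSE over the class-image half.
    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

    # Weighted sum, mirroring `loss + args.prior_loss_weight * prior_loss` in the patch.
    return loss + prior_loss_weight * prior_loss

# Dummy tensors shaped like SD latents: batch of 4 = 2 instance + 2 class samples.
pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)
print(prior_preservation_loss(pred, target))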
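The other technique worth calling out is the embedding freeze after accelerator.backward: when an initializer token is used, every text-encoder token embedding except the newly added placeholder is restored from original_token_embeds, so only the concept embedding actually trains. A toy sketch of that masking trick, assuming a made-up vocabulary size and placeholder_token_id:

import torch

vocab_size, embed_dim = 10, 8
placeholder_token_id = 7  # hypothetical id of the newly added token

token_embeds = torch.nn.Embedding(vocab_size, embed_dim)
original_token_embeds = token_embeds.weight.detach().clone()

# Stand-in for an optimizer step that updated every embedding row.
token_embeds.weight.data += 0.1

# Copy the original weights back into every row except the placeholder's,
# so only the new concept embedding keeps its update.
index_fixed_tokens = torch.arange(vocab_size) != placeholder_token_id
token_embeds.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

# Only the placeholder row should differ from the original now.
print((token_embeds.weight.data != original_token_embeds).any(dim=1))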