From b8ba49fe4c44aaaa30894e5abba22d3bbf94a562 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 28 Nov 2022 13:23:05 +0100 Subject: Fixed noise calculation for v-prediction --- textual_inversion.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 20b1617..fa7ae42 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -439,7 +439,7 @@ class Checkpointer: with torch.autocast("cuda"), torch.inference_mode(): for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: all_samples = [] - file_path = samples_path.joinpath(pool, f"step_{step}.png") + file_path = samples_path.joinpath(pool, f"step_{step}.jpg") file_path.parent.mkdir(parents=True, exist_ok=True) data_enum = enumerate(data) @@ -568,10 +568,6 @@ def main(): # Initialise the newly added placeholder token with the embeddings of the initializer token token_embeds = text_encoder.get_input_embeddings().weight.data original_token_embeds = token_embeds.detach().clone().to(accelerator.device) - initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) - - for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): - token_embeds[token_id] = embeddings if args.resume_checkpoint is not None: token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] @@ -817,6 +813,18 @@ def main(): ) global_progress_bar.set_description("Total progress") + def get_loss(noise_pred, noise, latents, timesteps): + if noise_scheduler.config.prediction_type == "v_prediction": + timesteps = timesteps.view(-1, 1, 1, 1) + alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps] + alpha_t = torch.sqrt(alphas_cumprod) + sigma_t = torch.sqrt(1 - alphas_cumprod) + target = alpha_t * noise - sigma_t * latents + else: + target = noise + + return F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + try: for epoch in range(num_epochs): local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -854,19 +862,20 @@ def main(): if args.num_class_images != 0: # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + latents, latents_prior = torch.chunk(noise_pred, 2, dim=0) noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) noise, noise_prior = torch.chunk(noise, 2, dim=0) # Compute instance loss - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = get_loss(noise_pred, noise, latents, timesteps) # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() + prior_loss = get_loss(noise_pred_prior, noise_prior, latents_prior, timesteps) # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = get_loss(noise_pred, noise, latents, timesteps) accelerator.backward(loss) @@ -922,6 +931,8 @@ def main(): accelerator.wait_for_everyone() + print(token_embeds[placeholder_token_id]) + text_encoder.eval() val_loss = 0.0 @@ -945,7 +956,7 @@ def main(): noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = get_loss(noise_pred, noise, latents, timesteps) loss = loss.detach().item() val_loss += loss -- cgit v1.2.3-54-g00ecf