From d69cc8f46f238e91e2f597cd301cc53b1d4b8bec Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 25 Mar 2023 17:49:30 +0100
Subject: Fix training with guidance

---
 data/csv.py            |  8 +++++---
 training/functional.py | 12 +++++++++---
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index a6cd065..d52d251 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -104,11 +104,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
+    negative_input_ids = [example["negative_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    if with_guidance:
-        input_ids += [example["negative_prompt_ids"] for example in examples]
-    elif with_prior_preservation:
+    if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
@@ -118,13 +117,16 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
     prompts = unify_input_ids(tokenizer, prompt_ids)
     nprompts = unify_input_ids(tokenizer, nprompt_ids)
     inputs = unify_input_ids(tokenizer, input_ids)
+    negative_inputs = unify_input_ids(tokenizer, negative_input_ids)
 
     batch = {
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
+        "negative_input_ids": negative_inputs.input_ids,
         "pixel_values": pixel_values,
         "attention_mask": inputs.attention_mask,
+        "negative_attention_mask": negative_inputs.attention_mask,
     }
 
     return batch
diff --git a/training/functional.py b/training/functional.py
index d285366..109845b 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -344,9 +344,15 @@ def loss_step(
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
     if guidance_scale != 0:
-        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-        model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0)
-        model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond)
+        uncond_encoder_hidden_states = get_extended_embeddings(
+            text_encoder,
+            batch["negative_input_ids"],
+            batch["negative_attention_mask"]
+        )
+        uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
+
+        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
+        model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
     elif prior_loss_weight != 0:
--
cgit v1.2.3-70-g09d2
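
Data-side note: collate_fn previously appended the negative prompt ids onto
input_ids when with_guidance was set, which doubled the batch and made that
branch mutually exclusive with prior preservation. After this patch the
negative prompts are padded into their own tensors (negative_input_ids /
negative_attention_mask). Below is a minimal sketch of that padding step,
assuming the repo's unify_input_ids helper behaves like Hugging Face's
tokenizer.pad; this is an assumption, the helper's actual implementation is
not shown in the patch.

    # Hypothetical stand-in for unify_input_ids: pad variable-length
    # token id lists into a rectangular tensor plus attention mask.
    from transformers import CLIPTokenizer

    def pad_prompt_ids(tokenizer: CLIPTokenizer, prompt_ids: list[list[int]]):
        padded = tokenizer.pad(
            {"input_ids": prompt_ids},
            padding=True,
            return_tensors="pt",
        )
        return padded.input_ids, padded.attention_mask

Keeping the negative prompts in separate batch keys means
with_prior_preservation can still extend input_ids with class prompts
without colliding with guidance.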
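
Training-side note: loss_step no longer chunks a doubled model_pred.
Instead it encodes the negative prompts and runs a second UNet forward pass
to get the unconditional prediction, then applies the classifier-free
guidance combination before the MSE loss. A minimal sketch of the resulting
computation, assuming a diffusers-style UNet whose output exposes .sample;
the function and argument names below are illustrative, not the repo's.

    import torch.nn.functional as F

    def guided_loss(unet, noisy_latents, timesteps, cond_emb, uncond_emb,
                    target, guidance_scale):
        # Conditional prediction from the instance-prompt embeddings
        # (the model_pred computed earlier in loss_step).
        pred_cond = unet(noisy_latents, timesteps, cond_emb).sample
        # Extra forward pass with the negative-prompt embeddings.
        pred_uncond = unet(noisy_latents, timesteps, uncond_emb).sample
        # Classifier-free guidance: push the prediction away from the
        # unconditional branch by guidance_scale.
        guided = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
        return F.mse_loss(guided.float(), target.float(), reduction="none")

The cost is one extra UNet forward pass per step whenever
guidance_scale != 0, in exchange for a batch layout that stays compatible
with prior preservation.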