 data/csv.py            |  8 +++++---
 training/functional.py | 12 +++++++++---
 2 files changed, 14 insertions(+), 6 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index a6cd065..d52d251 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -104,11 +104,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
+    negative_input_ids = [example["negative_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    if with_guidance:
-        input_ids += [example["negative_prompt_ids"] for example in examples]
-    elif with_prior_preservation:
+    if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
@@ -118,13 +117,16 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
     prompts = unify_input_ids(tokenizer, prompt_ids)
     nprompts = unify_input_ids(tokenizer, nprompt_ids)
     inputs = unify_input_ids(tokenizer, input_ids)
+    negative_inputs = unify_input_ids(tokenizer, negative_input_ids)
 
     batch = {
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
126 "negative_input_ids": negative_inputs.attention_mask,
126 "pixel_values": pixel_values, 127 "pixel_values": pixel_values,
127 "attention_mask": inputs.attention_mask, 128 "attention_mask": inputs.attention_mask,
129 "negative_attention_mask": negative_inputs.attention_mask,
128 } 130 }
129 131
130 return batch 132 return batch
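
For reference, the negative prompt ids added above are padded into tensors the same way as the positive ids. A minimal sketch of that padding step, assuming unify_input_ids is a thin wrapper around the tokenizer's pad() (the helper's actual body is not part of this diff):

# Hypothetical stand-in for unify_input_ids: pad ragged per-example token-id
# lists to a common length and expose .input_ids / .attention_mask tensors.
from transformers import CLIPTokenizer

def pad_prompt_ids(tokenizer: CLIPTokenizer, ids: list[list[int]]):
    return tokenizer.pad({"input_ids": ids}, padding=True, return_tensors="pt")

# Usage mirroring the new collate_fn lines:
#   negative_inputs = pad_prompt_ids(tokenizer, negative_input_ids)
#   batch["negative_input_ids"] = negative_inputs.input_ids
#   batch["negative_attention_mask"] = negative_inputs.attention_mask
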
diff --git a/training/functional.py b/training/functional.py
index d285366..109845b 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -344,9 +344,15 @@ def loss_step(
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
     if guidance_scale != 0:
-        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-        model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0)
-        model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond)
+        uncond_encoder_hidden_states = get_extended_embeddings(
+            text_encoder,
+            batch["negative_input_ids"],
+            batch["negative_attention_mask"]
+        )
+        uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
+
+        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
+        model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
 
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
     elif prior_loss_weight != 0:
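
Taken together, the training/functional.py change applies classifier-free guidance at training time: a second UNet pass conditioned on the negative prompt produces an unconditional prediction, and the conditional prediction is shifted away from it before the loss is taken. A minimal, self-contained sketch of that combination step (names are illustrative, not the repo's API):

# Illustrative sketch of the guidance-weighted loss computed above; the real
# code obtains model_pred / model_pred_uncond from two UNet forward passes.
import torch
import torch.nn.functional as F

def guided_mse_loss(model_pred: torch.Tensor,
                    model_pred_uncond: torch.Tensor,
                    target: torch.Tensor,
                    guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: shift the conditional prediction away from the
    # unconditional (negative-prompt) prediction by guidance_scale.
    guided = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
    # Per-element loss, as in loss_step (reduction happens later).
    return F.mse_loss(guided.float(), target.float(), reduction="none")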