diff options
| -rw-r--r-- | data/csv.py | 8 | ||||
| -rw-r--r-- | training/functional.py | 12 | 
2 files changed, 14 insertions, 6 deletions
| diff --git a/data/csv.py b/data/csv.py index a6cd065..d52d251 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -104,11 +104,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool | |||
| 104 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 104 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 
| 105 | 105 | ||
| 106 | input_ids = [example["instance_prompt_ids"] for example in examples] | 106 | input_ids = [example["instance_prompt_ids"] for example in examples] | 
| 107 | negative_input_ids = [example["negative_prompt_ids"] for example in examples] | ||
| 107 | pixel_values = [example["instance_images"] for example in examples] | 108 | pixel_values = [example["instance_images"] for example in examples] | 
| 108 | 109 | ||
| 109 | if with_guidance: | 110 | if with_prior_preservation: | 
| 110 | input_ids += [example["negative_prompt_ids"] for example in examples] | ||
| 111 | elif with_prior_preservation: | ||
| 112 | input_ids += [example["class_prompt_ids"] for example in examples] | 111 | input_ids += [example["class_prompt_ids"] for example in examples] | 
| 113 | pixel_values += [example["class_images"] for example in examples] | 112 | pixel_values += [example["class_images"] for example in examples] | 
| 114 | 113 | ||
| @@ -118,13 +117,16 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool | |||
| 118 | prompts = unify_input_ids(tokenizer, prompt_ids) | 117 | prompts = unify_input_ids(tokenizer, prompt_ids) | 
| 119 | nprompts = unify_input_ids(tokenizer, nprompt_ids) | 118 | nprompts = unify_input_ids(tokenizer, nprompt_ids) | 
| 120 | inputs = unify_input_ids(tokenizer, input_ids) | 119 | inputs = unify_input_ids(tokenizer, input_ids) | 
| 120 | negative_inputs = unify_input_ids(tokenizer, negative_input_ids) | ||
| 121 | 121 | ||
| 122 | batch = { | 122 | batch = { | 
| 123 | "prompt_ids": prompts.input_ids, | 123 | "prompt_ids": prompts.input_ids, | 
| 124 | "nprompt_ids": nprompts.input_ids, | 124 | "nprompt_ids": nprompts.input_ids, | 
| 125 | "input_ids": inputs.input_ids, | 125 | "input_ids": inputs.input_ids, | 
| 126 | "negative_input_ids": negative_inputs.attention_mask, | ||
| 126 | "pixel_values": pixel_values, | 127 | "pixel_values": pixel_values, | 
| 127 | "attention_mask": inputs.attention_mask, | 128 | "attention_mask": inputs.attention_mask, | 
| 129 | "negative_attention_mask": negative_inputs.attention_mask, | ||
| 128 | } | 130 | } | 
| 129 | 131 | ||
| 130 | return batch | 132 | return batch | 
| diff --git a/training/functional.py b/training/functional.py index d285366..109845b 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -344,9 +344,15 @@ def loss_step( | |||
| 344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 
| 345 | 345 | ||
| 346 | if guidance_scale != 0: | 346 | if guidance_scale != 0: | 
| 347 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 347 | uncond_encoder_hidden_states = get_extended_embeddings( | 
| 348 | model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0) | 348 | text_encoder, | 
| 349 | model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) | 349 | batch["negative_input_ids"], | 
| 350 | batch["negative_attention_mask"] | ||
| 351 | ) | ||
| 352 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | ||
| 353 | |||
| 354 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample | ||
| 355 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | ||
| 350 | 356 | ||
| 351 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 357 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 
| 352 | elif prior_loss_weight != 0: | 358 | elif prior_loss_weight != 0: | 
