diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 18 |
1 files changed, 1 insertions, 17 deletions
diff --git a/training/functional.py b/training/functional.py index a3d1f08..43b03ac 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -342,7 +342,6 @@ def loss_step( | |||
342 | schedule_sampler: ScheduleSampler, | 342 | schedule_sampler: ScheduleSampler, |
343 | unet: UNet2DConditionModel, | 343 | unet: UNet2DConditionModel, |
344 | text_encoder: CLIPTextModel, | 344 | text_encoder: CLIPTextModel, |
345 | guidance_scale: float, | ||
346 | prior_loss_weight: float, | 345 | prior_loss_weight: float, |
347 | seed: int, | 346 | seed: int, |
348 | input_pertubation: float, | 347 | input_pertubation: float, |
@@ -400,19 +399,6 @@ def loss_step( | |||
400 | noisy_latents, timesteps, encoder_hidden_states, return_dict=False | 399 | noisy_latents, timesteps, encoder_hidden_states, return_dict=False |
401 | )[0] | 400 | )[0] |
402 | 401 | ||
403 | if guidance_scale != 0: | ||
404 | uncond_encoder_hidden_states = get_extended_embeddings( | ||
405 | text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"] | ||
406 | ) | ||
407 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | ||
408 | |||
409 | model_pred_uncond = unet( | ||
410 | noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False | ||
411 | )[0] | ||
412 | model_pred = model_pred_uncond + guidance_scale * ( | ||
413 | model_pred - model_pred_uncond | ||
414 | ) | ||
415 | |||
416 | # Get the target for loss depending on the prediction type | 402 | # Get the target for loss depending on the prediction type |
417 | if noise_scheduler.config.prediction_type == "epsilon": | 403 | if noise_scheduler.config.prediction_type == "epsilon": |
418 | target = noise | 404 | target = noise |
@@ -425,7 +411,7 @@ def loss_step( | |||
425 | 411 | ||
426 | acc = (model_pred == target).float().mean() | 412 | acc = (model_pred == target).float().mean() |
427 | 413 | ||
428 | if guidance_scale == 0 and prior_loss_weight != 0: | 414 | if prior_loss_weight != 0: |
429 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 415 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
430 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 416 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
431 | target, target_prior = torch.chunk(target, 2, dim=0) | 417 | target, target_prior = torch.chunk(target, 2, dim=0) |
@@ -727,7 +713,6 @@ def train( | |||
727 | milestone_checkpoints: bool = True, | 713 | milestone_checkpoints: bool = True, |
728 | cycle: int = 1, | 714 | cycle: int = 1, |
729 | global_step_offset: int = 0, | 715 | global_step_offset: int = 0, |
730 | guidance_scale: float = 0.0, | ||
731 | prior_loss_weight: float = 1.0, | 716 | prior_loss_weight: float = 1.0, |
732 | input_pertubation: float = 0.1, | 717 | input_pertubation: float = 0.1, |
733 | schedule_sampler: Optional[ScheduleSampler] = None, | 718 | schedule_sampler: Optional[ScheduleSampler] = None, |
@@ -787,7 +772,6 @@ def train( | |||
787 | schedule_sampler, | 772 | schedule_sampler, |
788 | unet, | 773 | unet, |
789 | text_encoder, | 774 | text_encoder, |
790 | guidance_scale, | ||
791 | prior_loss_weight, | 775 | prior_loss_weight, |
792 | seed, | 776 | seed, |
793 | input_pertubation, | 777 | input_pertubation, |