summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py18
1 files changed, 1 insertions, 17 deletions
diff --git a/training/functional.py b/training/functional.py
index a3d1f08..43b03ac 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -342,7 +342,6 @@ def loss_step(
342 schedule_sampler: ScheduleSampler, 342 schedule_sampler: ScheduleSampler,
343 unet: UNet2DConditionModel, 343 unet: UNet2DConditionModel,
344 text_encoder: CLIPTextModel, 344 text_encoder: CLIPTextModel,
345 guidance_scale: float,
346 prior_loss_weight: float, 345 prior_loss_weight: float,
347 seed: int, 346 seed: int,
348 input_pertubation: float, 347 input_pertubation: float,
@@ -400,19 +399,6 @@ def loss_step(
400 noisy_latents, timesteps, encoder_hidden_states, return_dict=False 399 noisy_latents, timesteps, encoder_hidden_states, return_dict=False
401 )[0] 400 )[0]
402 401
403 if guidance_scale != 0:
404 uncond_encoder_hidden_states = get_extended_embeddings(
405 text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"]
406 )
407 uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
408
409 model_pred_uncond = unet(
410 noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False
411 )[0]
412 model_pred = model_pred_uncond + guidance_scale * (
413 model_pred - model_pred_uncond
414 )
415
416 # Get the target for loss depending on the prediction type 402 # Get the target for loss depending on the prediction type
417 if noise_scheduler.config.prediction_type == "epsilon": 403 if noise_scheduler.config.prediction_type == "epsilon":
418 target = noise 404 target = noise
@@ -425,7 +411,7 @@ def loss_step(
425 411
426 acc = (model_pred == target).float().mean() 412 acc = (model_pred == target).float().mean()
427 413
428 if guidance_scale == 0 and prior_loss_weight != 0: 414 if prior_loss_weight != 0:
429 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 415 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
430 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 416 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
431 target, target_prior = torch.chunk(target, 2, dim=0) 417 target, target_prior = torch.chunk(target, 2, dim=0)
@@ -727,7 +713,6 @@ def train(
727 milestone_checkpoints: bool = True, 713 milestone_checkpoints: bool = True,
728 cycle: int = 1, 714 cycle: int = 1,
729 global_step_offset: int = 0, 715 global_step_offset: int = 0,
730 guidance_scale: float = 0.0,
731 prior_loss_weight: float = 1.0, 716 prior_loss_weight: float = 1.0,
732 input_pertubation: float = 0.1, 717 input_pertubation: float = 0.1,
733 schedule_sampler: Optional[ScheduleSampler] = None, 718 schedule_sampler: Optional[ScheduleSampler] = None,
@@ -787,7 +772,6 @@ def train(
787 schedule_sampler, 772 schedule_sampler,
788 unet, 773 unet,
789 text_encoder, 774 text_encoder,
790 guidance_scale,
791 prior_loss_weight, 775 prior_loss_weight,
792 seed, 776 seed,
793 input_pertubation, 777 input_pertubation,