From 11b7740deeef7903e81ba4c65a45853323a5fd5e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 22 Jun 2023 07:34:04 +0200 Subject: Remove training guidance_scale --- training/functional.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index a3d1f08..43b03ac 100644 --- a/training/functional.py +++ b/training/functional.py @@ -342,7 +342,6 @@ def loss_step( schedule_sampler: ScheduleSampler, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, - guidance_scale: float, prior_loss_weight: float, seed: int, input_pertubation: float, @@ -400,19 +399,6 @@ def loss_step( noisy_latents, timesteps, encoder_hidden_states, return_dict=False )[0] - if guidance_scale != 0: - uncond_encoder_hidden_states = get_extended_embeddings( - text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"] - ) - uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) - - model_pred_uncond = unet( - noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False - )[0] - model_pred = model_pred_uncond + guidance_scale * ( - model_pred - model_pred_uncond - ) - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise @@ -425,7 +411,7 @@ def loss_step( acc = (model_pred == target).float().mean() - if guidance_scale == 0 and prior_loss_weight != 0: + if prior_loss_weight != 0: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) @@ -727,7 +713,6 @@ def train( milestone_checkpoints: bool = True, cycle: int = 1, global_step_offset: int = 0, - guidance_scale: float = 0.0, prior_loss_weight: float = 1.0, input_pertubation: float = 0.1, schedule_sampler: Optional[ScheduleSampler] = None, @@ -787,7 +772,6 @@ def train( schedule_sampler, unet, text_encoder, - guidance_scale, prior_loss_weight, seed, input_pertubation, -- cgit v1.2.3-70-g09d2