diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-22 07:34:04 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-22 07:34:04 +0200 |
| commit | 11b7740deeef7903e81ba4c65a45853323a5fd5e (patch) | |
| tree | 0d5eec0e574447afe80538252884e0093742e3cb /training | |
| parent | Remove convnext (diff) | |
| download | textual-inversion-diff-11b7740deeef7903e81ba4c65a45853323a5fd5e.tar.gz textual-inversion-diff-11b7740deeef7903e81ba4c65a45853323a5fd5e.tar.bz2 textual-inversion-diff-11b7740deeef7903e81ba4c65a45853323a5fd5e.zip | |
Remove training guidance_scale
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 18 |
1 files changed, 1 insertions, 17 deletions
diff --git a/training/functional.py b/training/functional.py index a3d1f08..43b03ac 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -342,7 +342,6 @@ def loss_step( | |||
| 342 | schedule_sampler: ScheduleSampler, | 342 | schedule_sampler: ScheduleSampler, |
| 343 | unet: UNet2DConditionModel, | 343 | unet: UNet2DConditionModel, |
| 344 | text_encoder: CLIPTextModel, | 344 | text_encoder: CLIPTextModel, |
| 345 | guidance_scale: float, | ||
| 346 | prior_loss_weight: float, | 345 | prior_loss_weight: float, |
| 347 | seed: int, | 346 | seed: int, |
| 348 | input_pertubation: float, | 347 | input_pertubation: float, |
| @@ -400,19 +399,6 @@ def loss_step( | |||
| 400 | noisy_latents, timesteps, encoder_hidden_states, return_dict=False | 399 | noisy_latents, timesteps, encoder_hidden_states, return_dict=False |
| 401 | )[0] | 400 | )[0] |
| 402 | 401 | ||
| 403 | if guidance_scale != 0: | ||
| 404 | uncond_encoder_hidden_states = get_extended_embeddings( | ||
| 405 | text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"] | ||
| 406 | ) | ||
| 407 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | ||
| 408 | |||
| 409 | model_pred_uncond = unet( | ||
| 410 | noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False | ||
| 411 | )[0] | ||
| 412 | model_pred = model_pred_uncond + guidance_scale * ( | ||
| 413 | model_pred - model_pred_uncond | ||
| 414 | ) | ||
| 415 | |||
| 416 | # Get the target for loss depending on the prediction type | 402 | # Get the target for loss depending on the prediction type |
| 417 | if noise_scheduler.config.prediction_type == "epsilon": | 403 | if noise_scheduler.config.prediction_type == "epsilon": |
| 418 | target = noise | 404 | target = noise |
| @@ -425,7 +411,7 @@ def loss_step( | |||
| 425 | 411 | ||
| 426 | acc = (model_pred == target).float().mean() | 412 | acc = (model_pred == target).float().mean() |
| 427 | 413 | ||
| 428 | if guidance_scale == 0 and prior_loss_weight != 0: | 414 | if prior_loss_weight != 0: |
| 429 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 415 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 430 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 416 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| 431 | target, target_prior = torch.chunk(target, 2, dim=0) | 417 | target, target_prior = torch.chunk(target, 2, dim=0) |
| @@ -727,7 +713,6 @@ def train( | |||
| 727 | milestone_checkpoints: bool = True, | 713 | milestone_checkpoints: bool = True, |
| 728 | cycle: int = 1, | 714 | cycle: int = 1, |
| 729 | global_step_offset: int = 0, | 715 | global_step_offset: int = 0, |
| 730 | guidance_scale: float = 0.0, | ||
| 731 | prior_loss_weight: float = 1.0, | 716 | prior_loss_weight: float = 1.0, |
| 732 | input_pertubation: float = 0.1, | 717 | input_pertubation: float = 0.1, |
| 733 | schedule_sampler: Optional[ScheduleSampler] = None, | 718 | schedule_sampler: Optional[ScheduleSampler] = None, |
| @@ -787,7 +772,6 @@ def train( | |||
| 787 | schedule_sampler, | 772 | schedule_sampler, |
| 788 | unet, | 773 | unet, |
| 789 | text_encoder, | 774 | text_encoder, |
| 790 | guidance_scale, | ||
| 791 | prior_loss_weight, | 775 | prior_loss_weight, |
| 792 | seed, | 776 | seed, |
| 793 | input_pertubation, | 777 | input_pertubation, |
