| author | Volpeon <git@volpeon.ink> | 2023-01-05 13:26:32 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-05 13:26:32 +0100 |
| commit | 3396ca881ed3f3521617cd9024eea56975191d32 (patch) | |
| tree | 3189c3bbe77b211152d11b524d0fe3a7016441ee /training | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.tar.gz, textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.tar.bz2, textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.zip | |
Update
Diffstat (limited to 'training')
| -rw-r--r-- | training/common.py | 5 |
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/training/common.py b/training/common.py
index 99a6e67..ab2741a 100644
--- a/training/common.py
+++ b/training/common.py
@@ -40,7 +40,10 @@ def run_model(
     noisy_latents = noisy_latents.to(dtype=unet.dtype)

     # Get the text embedding for conditioning
-    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+    encoder_hidden_states = prompt_processor.get_embeddings(
+        batch["input_ids"],
+        batch["attention_mask"]
+    )
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

     # Predict the noise residual
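The change only reformats the `get_embeddings` call, making it explicit that both the token IDs and the attention mask from the batch are passed to the prompt processor. For context, here is a minimal sketch of what a helper like `get_embeddings` typically does, assuming the prompt processor wraps a Hugging Face text encoder (e.g. `transformers.CLIPTextModel`); the class and attribute names below are illustrative, not the repository's actual implementation.

```python
# Illustrative sketch only; not the repository's PromptProcessor.
class PromptProcessorSketch:
    def __init__(self, text_encoder):
        # Assumed: a Hugging Face text encoder such as CLIPTextModel.
        self.text_encoder = text_encoder

    def get_embeddings(self, input_ids, attention_mask=None):
        # The encoder's last hidden state (one vector per token) is what
        # the UNet later consumes as encoder_hidden_states for conditioning.
        return self.text_encoder(input_ids, attention_mask=attention_mask)[0]
```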
