summaryrefslogtreecommitdiffstats
path: root/training/common.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 09:35:42 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 09:35:42 +0100
commit6c38d0088ece492696a7bc94a5cb43a48289452a (patch)
treed84a8fefd52eba5cbf38e64d34962f34dc6d047d /training/common.py
parentCleanup (diff)
downloadtextual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.gz
textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.tar.bz2
textual-inversion-diff-6c38d0088ece492696a7bc94a5cb43a48289452a.zip
Fix
Diffstat (limited to 'training/common.py')
-rw-r--r--training/common.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/training/common.py b/training/common.py
index f5ab326..8083137 100644
--- a/training/common.py
+++ b/training/common.py
@@ -184,7 +184,7 @@ def loss_step(
184 else: 184 else:
185 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 185 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
186 186
187 if batch["with_prior"]: 187 if batch["with_prior"].all():
188 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 188 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
189 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 189 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
190 target, target_prior = torch.chunk(target, 2, dim=0) 190 target, target_prior = torch.chunk(target, 2, dim=0)