author     Volpeon <git@volpeon.ink>    2023-02-21 12:03:00 +0100
committer  Volpeon <git@volpeon.ink>    2023-02-21 12:03:00 +0100
commit     a5cdb510002324b6e6cf8297ee4cfd6f25330ed2 (patch)
tree       ecde6ce7a7ebbc59b4b05606507b2ecb8299256f /training
parent     Don't rely on Accelerate for gradient accumulation (diff)
Fix
Diffstat (limited to 'training')
 training/functional.py | 9
 1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 3f5fa7e..e7c4320 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -375,7 +375,6 @@ def train_loop(
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
-    train_step = 0
 
     avg_loss = AverageMeter()
     avg_acc = AverageMeter()
@@ -439,11 +438,11 @@ def train_loop(
                 loss, acc, bsz = loss_step(step, batch)
                 loss /= gradient_accumulation_steps
 
+                accelerator.backward(loss)
+
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
-                accelerator.backward(loss)
-
                 logs = {
                     "train/loss": avg_loss.avg.item(),
                     "train/acc": avg_acc.avg.item(),
@@ -455,9 +454,7 @@ def train_loop(
 
                 local_progress_bar.set_postfix(**logs)
 
-                train_step += 1
-
-                if train_step % gradient_accumulation_steps == 0:
+                if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
                     on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
                     optimizer.step()
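
The diff makes two related changes: accelerator.backward(loss) now runs before avg_loss.update(loss.detach_(), bsz), so the in-place detach can no longer cut the autograd graph ahead of the backward pass, and the optimizer-step condition is driven by (step + 1) from the dataloader loop rather than a separate train_step counter, which also lets the final, possibly partial, accumulation window still trigger a step. Below is a minimal plain-PyTorch sketch of the same manual gradient-accumulation pattern; the names (model, optimizer, lr_scheduler, train_dataloader, loss_fn, gradient_accumulation_steps) are illustrative placeholders rather than the repository's actual objects, and Accelerate is left out for brevity.

def train_one_epoch(model, optimizer, lr_scheduler, train_dataloader,
                    loss_fn, gradient_accumulation_steps=4):
    model.train()
    for step, batch in enumerate(train_dataloader):
        # Scale the loss so the accumulated gradients match a full batch.
        loss = loss_fn(model, batch) / gradient_accumulation_steps

        # Backward must run before the loss is detached for logging;
        # detaching in place first would discard the autograd graph.
        loss.backward()
        logged_loss = loss.detach()

        # Step on every full accumulation window and on the final,
        # possibly partial, window at the end of the epoch.
        if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)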