diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-21 12:03:00 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-21 12:03:00 +0100 | 
| commit | a5cdb510002324b6e6cf8297ee4cfd6f25330ed2 (patch) | |
| tree | ecde6ce7a7ebbc59b4b05606507b2ecb8299256f /training | |
| parent | Don't rely on Accelerate for gradient accumulation (diff) | |
| download | textual-inversion-diff-a5cdb510002324b6e6cf8297ee4cfd6f25330ed2.tar.gz textual-inversion-diff-a5cdb510002324b6e6cf8297ee4cfd6f25330ed2.tar.bz2 textual-inversion-diff-a5cdb510002324b6e6cf8297ee4cfd6f25330ed2.zip | |
Fix
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 9 | 
1 file changed, 3 insertions, 6 deletions
| diff --git a/training/functional.py b/training/functional.py index 3f5fa7e..e7c4320 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -375,7 +375,6 @@ def train_loop( | |||
| 375 | num_val_steps = num_val_steps_per_epoch * num_epochs | 375 | num_val_steps = num_val_steps_per_epoch * num_epochs | 
| 376 | 376 | ||
| 377 | global_step = 0 | 377 | global_step = 0 | 
| 378 | train_step = 0 | ||
| 379 | 378 | ||
| 380 | avg_loss = AverageMeter() | 379 | avg_loss = AverageMeter() | 
| 381 | avg_acc = AverageMeter() | 380 | avg_acc = AverageMeter() | 
| @@ -439,11 +438,11 @@ def train_loop( | |||
| 439 | loss, acc, bsz = loss_step(step, batch) | 438 | loss, acc, bsz = loss_step(step, batch) | 
| 440 | loss /= gradient_accumulation_steps | 439 | loss /= gradient_accumulation_steps | 
| 441 | 440 | ||
| 441 | accelerator.backward(loss) | ||
| 442 | |||
| 442 | avg_loss.update(loss.detach_(), bsz) | 443 | avg_loss.update(loss.detach_(), bsz) | 
| 443 | avg_acc.update(acc.detach_(), bsz) | 444 | avg_acc.update(acc.detach_(), bsz) | 
| 444 | 445 | ||
| 445 | accelerator.backward(loss) | ||
| 446 | |||
| 447 | logs = { | 446 | logs = { | 
| 448 | "train/loss": avg_loss.avg.item(), | 447 | "train/loss": avg_loss.avg.item(), | 
| 449 | "train/acc": avg_acc.avg.item(), | 448 | "train/acc": avg_acc.avg.item(), | 
| @@ -455,9 +454,7 @@ def train_loop( | |||
| 455 | 454 | ||
| 456 | local_progress_bar.set_postfix(**logs) | 455 | local_progress_bar.set_postfix(**logs) | 
| 457 | 456 | ||
| 458 | train_step += 1 | 457 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): | 
| 459 | |||
| 460 | if train_step % gradient_accumulation_steps == 0: | ||
| 461 | on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) | 458 | on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) | 
| 462 | 459 | ||
| 463 | optimizer.step() | 460 | optimizer.step() | 
