summary | refs | log | tree | commit | diff | stats
path: root/training/functional.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2023-02-21 11:50:11 +0100
committer: Volpeon <git@volpeon.ink> 2023-02-21 11:50:11 +0100
commit: 9d6252e63bac241e5c6191eb47adb51b84a5d782 (patch)
tree: 6cb649510b48ca33419af3721e630f1c06bf1ae2 /training/functional.py
parentEmbedding normalization: Ignore tensors with grad = 0 (diff)
download: textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.tar.gz
textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.tar.bz2
textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.zip
Don't rely on Accelerate for gradient accumulation
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  53
1 file changed, 29 insertions(+), 24 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 739d055..3f5fa7e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -365,15 +365,17 @@ def train_loop(
365 milestone_checkpoints: bool = True, 365 milestone_checkpoints: bool = True,
366 global_step_offset: int = 0, 366 global_step_offset: int = 0,
367 num_epochs: int = 100, 367 num_epochs: int = 100,
368 gradient_accumulation_steps: int = 1,
368 callbacks: TrainingCallbacks = TrainingCallbacks(), 369 callbacks: TrainingCallbacks = TrainingCallbacks(),
369): 370):
370 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) 371 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
371 num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0 372 num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
372 373
373 num_training_steps = num_training_steps_per_epoch * num_epochs 374 num_training_steps = num_training_steps_per_epoch * num_epochs
374 num_val_steps = num_val_steps_per_epoch * num_epochs 375 num_val_steps = num_val_steps_per_epoch * num_epochs
375 376
376 global_step = 0 377 global_step = 0
378 train_step = 0
377 379
378 avg_loss = AverageMeter() 380 avg_loss = AverageMeter()
379 avg_acc = AverageMeter() 381 avg_acc = AverageMeter()
@@ -434,44 +436,45 @@ def train_loop(
434 436
435 with on_train(epoch): 437 with on_train(epoch):
436 for step, batch in enumerate(train_dataloader): 438 for step, batch in enumerate(train_dataloader):
437 with accelerator.accumulate(model): 439 loss, acc, bsz = loss_step(step, batch)
438 loss, acc, bsz = loss_step(step, batch) 440 loss /= gradient_accumulation_steps
439 441
440 accelerator.backward(loss) 442 avg_loss.update(loss.detach_(), bsz)
443 avg_acc.update(acc.detach_(), bsz)
441 444
445 accelerator.backward(loss)
446
447 logs = {
448 "train/loss": avg_loss.avg.item(),
449 "train/acc": avg_acc.avg.item(),
450 "train/cur_loss": loss.item(),
451 "train/cur_acc": acc.item(),
452 "lr": lr_scheduler.get_last_lr()[0],
453 }
454 logs.update(on_log())
455
456 local_progress_bar.set_postfix(**logs)
457
458 train_step += 1
459
460 if train_step % gradient_accumulation_steps == 0:
442 on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) 461 on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
443 462
444 optimizer.step() 463 optimizer.step()
445 lr_scheduler.step() 464 lr_scheduler.step()
446 optimizer.zero_grad(set_to_none=True) 465 optimizer.zero_grad(set_to_none=True)
447 466
448 avg_loss.update(loss.detach_(), bsz)
449 avg_acc.update(acc.detach_(), bsz)
450
451 # Checks if the accelerator has performed an optimization step behind the scenes
452 if accelerator.sync_gradients:
453 on_after_optimize(lr_scheduler.get_last_lr()[0]) 467 on_after_optimize(lr_scheduler.get_last_lr()[0])
454 468
455 local_progress_bar.update(1) 469 local_progress_bar.update(1)
456 global_progress_bar.update(1) 470 global_progress_bar.update(1)
457 471
458 global_step += 1 472 accelerator.log(logs, step=global_step)
459 473
460 logs = { 474 global_step += 1
461 "train/loss": avg_loss.avg.item(),
462 "train/acc": avg_acc.avg.item(),
463 "train/cur_loss": loss.item(),
464 "train/cur_acc": acc.item(),
465 "lr": lr_scheduler.get_last_lr()[0],
466 }
467 logs.update(on_log())
468
469 accelerator.log(logs, step=global_step)
470
471 local_progress_bar.set_postfix(**logs)
472 475
473 if global_step >= num_training_steps: 476 if global_step >= num_training_steps:
474 break 477 break
475 478
476 accelerator.wait_for_everyone() 479 accelerator.wait_for_everyone()
477 480
@@ -571,6 +574,7 @@ def train(
571 strategy: TrainingStrategy, 574 strategy: TrainingStrategy,
572 no_val: bool = False, 575 no_val: bool = False,
573 num_train_epochs: int = 100, 576 num_train_epochs: int = 100,
577 gradient_accumulation_steps: int = 1,
574 sample_frequency: int = 20, 578 sample_frequency: int = 20,
575 checkpoint_frequency: int = 50, 579 checkpoint_frequency: int = 50,
576 milestone_checkpoints: bool = True, 580 milestone_checkpoints: bool = True,
@@ -631,6 +635,7 @@ def train(
631 milestone_checkpoints=milestone_checkpoints, 635 milestone_checkpoints=milestone_checkpoints,
632 global_step_offset=global_step_offset, 636 global_step_offset=global_step_offset,
633 num_epochs=num_train_epochs, 637 num_epochs=num_train_epochs,
638 gradient_accumulation_steps=gradient_accumulation_steps,
634 callbacks=callbacks, 639 callbacks=callbacks,
635 ) 640 )
636 641