From 9d6252e63bac241e5c6191eb47adb51b84a5d782 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 21 Feb 2023 11:50:11 +0100
Subject: Don't rely on Accelerate for gradient accumulation

---
 training/functional.py          | 53 ++++++++++++++++++++++-------------------
 training/strategy/dreambooth.py |  6 -----
 2 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 739d055..3f5fa7e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -365,15 +365,17 @@ def train_loop(
     milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
+    gradient_accumulation_steps: int = 1,
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
+    train_step = 0
 
     avg_loss = AverageMeter()
     avg_acc = AverageMeter()
@@ -434,44 +436,45 @@ def train_loop(
 
         with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(model):
-                    loss, acc, bsz = loss_step(step, batch)
+                loss, acc, bsz = loss_step(step, batch)
+                loss /= gradient_accumulation_steps
 
-                    accelerator.backward(loss)
+                avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
 
+                accelerator.backward(loss)
+
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
+                logs.update(on_log())
+
+                local_progress_bar.set_postfix(**logs)
+
+                train_step += 1
+
+                if train_step % gradient_accumulation_steps == 0:
                     on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
 
-                    avg_loss.update(loss.detach_(), bsz)
-                    avg_acc.update(acc.detach_(), bsz)
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
                     on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                    global_step += 1
+                    accelerator.log(logs, step=global_step)
 
-                    logs = {
-                        "train/loss": avg_loss.avg.item(),
-                        "train/acc": avg_acc.avg.item(),
-                        "train/cur_loss": loss.item(),
-                        "train/cur_acc": acc.item(),
-                        "lr": lr_scheduler.get_last_lr()[0],
-                    }
-                    logs.update(on_log())
-
-                    accelerator.log(logs, step=global_step)
-
-                    local_progress_bar.set_postfix(**logs)
+                    global_step += 1
 
-                if global_step >= num_training_steps:
-                    break
+                    if global_step >= num_training_steps:
+                        break
 
         accelerator.wait_for_everyone()
 
@@ -571,6 +574,7 @@ def train(
     strategy: TrainingStrategy,
     no_val: bool = False,
     num_train_epochs: int = 100,
+    gradient_accumulation_steps: int = 1,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
@@ -631,6 +635,7 @@ def train(
         milestone_checkpoints=milestone_checkpoints,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
+        gradient_accumulation_steps=gradient_accumulation_steps,
         callbacks=callbacks,
     )
 
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index d697554..fcf5c0d 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -41,12 +41,6 @@ def dreambooth_strategy_callbacks(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
-    if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
-
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
 
-- 
cgit v1.2.3-70-g09d2
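
Note: the pattern this patch switches to (manual gradient accumulation instead of
accelerator.accumulate()) reduces to a few lines of plain PyTorch. The sketch below is
illustrative only; the model, optimizer, dataloader, and hyperparameters are hypothetical
stand-ins, not code from this repository.

import torch
import torch.nn as nn

# Toy stand-ins for the repository's model, optimizer, loss_step, and dataloader.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
dataloader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(16)]

gradient_accumulation_steps = 4
train_step = 0

for inputs, targets in dataloader:
    loss = loss_fn(model(inputs), targets)
    # Scale the loss so the accumulated gradient averages over the effective
    # (accumulated) batch, as the patch does with `loss /= gradient_accumulation_steps`.
    loss = loss / gradient_accumulation_steps
    loss.backward()

    train_step += 1
    # Step and zero the optimizer only once per accumulation window, mirroring the
    # `train_step % gradient_accumulation_steps == 0` check added to train_loop().
    if train_step % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)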