|  |  |  |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-02-21 11:50:11 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2023-02-21 11:50:11 +0100 |
| commit | 9d6252e63bac241e5c6191eb47adb51b84a5d782 (patch) | |
| tree | 6cb649510b48ca33419af3721e630f1c06bf1ae2 | |
| parent | Embedding normalization: Ignore tensors with grad = 0 (diff) | |
| download | textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.tar.gz textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.tar.bz2 textual-inversion-diff-9d6252e63bac241e5c6191eb47adb51b84a5d782.zip | |
Don't rely on Accelerate for gradient accumulation
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_dreambooth.py | 2 |
| -rw-r--r-- | train_lora.py | 2 |
| -rw-r--r-- | train_ti.py | 2 |
| -rw-r--r-- | training/functional.py | 53 |
| -rw-r--r-- | training/strategy/dreambooth.py | 6 |

5 files changed, 32 insertions, 33 deletions
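The commit removes `gradient_accumulation_steps` from the `Accelerator(...)` constructor in the three entry scripts and instead threads it through `train()` into `train_loop()`, which now accumulates gradients itself: it scales each micro-batch loss, counts micro-steps, and only runs the optimizer every `gradient_accumulation_steps` steps. A minimal, self-contained sketch of that pattern (the toy model, optimizer, and data below are hypothetical stand-ins, not the project's code):

```python
# Sketch of manual gradient accumulation, assuming a generic PyTorch setup.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)
data = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(12)]

gradient_accumulation_steps = 4
train_step = 0
global_step = 0

for inputs, targets in data:
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    # Scale so the gradients summed over one window match the mean-loss gradient.
    loss = loss / gradient_accumulation_steps
    loss.backward()

    train_step += 1
    if train_step % gradient_accumulation_steps == 0:
        # Only every `gradient_accumulation_steps` micro-batches: optimizer step,
        # scheduler step, and gradient reset.
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        global_step += 1

print(global_step)  # 3 optimizer steps for 12 micro-batches
```

With 12 micro-batches and an accumulation window of 4, the sketch performs 3 optimizer steps, mirroring `num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)` in `training/functional.py`.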
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 431ff3d..280cf77 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -439,7 +439,6 @@ def main():
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{output_dir}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
@@ -590,6 +589,7 @@ def main():
         lr_scheduler=lr_scheduler,
         prepare_unet=True,
         num_train_epochs=args.num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
```
```diff
diff --git a/train_lora.py b/train_lora.py
index a06591d..d7c2de0 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -399,7 +399,6 @@ def main():
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{output_dir}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
@@ -561,6 +560,7 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
```
```diff
diff --git a/train_ti.py b/train_ti.py
index 6dc07dd..68783ea 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -518,7 +518,6 @@ def main():
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         logging_dir=f"{output_dir}",
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
@@ -611,6 +610,7 @@ def main():
         low_freq_noise=0,
         strategy=textual_inversion_strategy,
         num_train_epochs=args.num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
         milestone_checkpoints=not args.no_milestone_checkpoints,
```
```diff
diff --git a/training/functional.py b/training/functional.py
index 739d055..3f5fa7e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -365,15 +365,17 @@ def train_loop(
     milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
+    gradient_accumulation_steps: int = 1,
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
+    train_step = 0
 
     avg_loss = AverageMeter()
     avg_acc = AverageMeter()
@@ -434,44 +436,45 @@ def train_loop(
 
         with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(model):
-                    loss, acc, bsz = loss_step(step, batch)
+                loss, acc, bsz = loss_step(step, batch)
+                loss /= gradient_accumulation_steps
 
-                    accelerator.backward(loss)
+                avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
 
+                accelerator.backward(loss)
+
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
+                logs.update(on_log())
+
+                local_progress_bar.set_postfix(**logs)
+
+                train_step += 1
+
+                if train_step % gradient_accumulation_steps == 0:
                     on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
 
-                    avg_loss.update(loss.detach_(), bsz)
-                    avg_acc.update(acc.detach_(), bsz)
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
                     on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                    global_step += 1
+                    accelerator.log(logs, step=global_step)
 
-                    logs = {
-                        "train/loss": avg_loss.avg.item(),
-                        "train/acc": avg_acc.avg.item(),
-                        "train/cur_loss": loss.item(),
-                        "train/cur_acc": acc.item(),
-                        "lr": lr_scheduler.get_last_lr()[0],
-                    }
-                    logs.update(on_log())
-
-                    accelerator.log(logs, step=global_step)
-
-                    local_progress_bar.set_postfix(**logs)
+                    global_step += 1
 
                 if global_step >= num_training_steps:
                     break
 
         accelerator.wait_for_everyone()
 
@@ -571,6 +574,7 @@ def train(
     strategy: TrainingStrategy,
     no_val: bool = False,
     num_train_epochs: int = 100,
+    gradient_accumulation_steps: int = 1,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
@@ -631,6 +635,7 @@ def train(
         milestone_checkpoints=milestone_checkpoints,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
+        gradient_accumulation_steps=gradient_accumulation_steps,
         callbacks=callbacks,
     )
 
```
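Since Accelerate no longer averages anything behind the scenes, `train_loop` now divides each micro-batch loss by `gradient_accumulation_steps` before calling `accelerator.backward(loss)`; the gradients accumulated over one window then equal the gradient of the mean loss over that window. A small sanity check of that equivalence (toy linear model, for illustration only):

```python
# Verify: accumulating backward(loss / G) over G micro-batches gives the same
# gradient as one backward pass on the mean of the G losses.
import torch

torch.manual_seed(0)
G = 4  # accumulation window, mirrors gradient_accumulation_steps
xs = [torch.randn(2, 3) for _ in range(G)]
ys = [torch.randn(2, 1) for _ in range(G)]

model_a = torch.nn.Linear(3, 1)
model_b = torch.nn.Linear(3, 1)
model_b.load_state_dict(model_a.state_dict())  # identical weights

# (a) manual accumulation with the 1/G scaling
for x, y in zip(xs, ys):
    (torch.nn.functional.mse_loss(model_a(x), y) / G).backward()

# (b) single backward pass on the mean of the per-micro-batch losses
torch.stack([
    torch.nn.functional.mse_loss(model_b(x), y) for x, y in zip(xs, ys)
]).mean().backward()

print(torch.allclose(model_a.weight.grad, model_b.weight.grad))  # True
print(torch.allclose(model_a.bias.grad, model_b.bias.grad))      # True
```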
```diff
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index d697554..fcf5c0d 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -41,12 +41,6 @@ def dreambooth_strategy_callbacks(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
-    if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
-
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
 
```
