Diffstat (limited to 'training')
-rw-r--r--   training/functional.py            | 53
-rw-r--r--   training/strategy/dreambooth.py   |  6
2 files changed, 29 insertions, 30 deletions
diff --git a/training/functional.py b/training/functional.py
index 739d055..3f5fa7e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -365,15 +365,17 @@ def train_loop(
     milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
+    gradient_accumulation_steps: int = 1,
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
+    train_step = 0
 
     avg_loss = AverageMeter()
     avg_acc = AverageMeter()
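
As a quick sanity check of the per-epoch step count computed in the hunk above, here is a worked example with hypothetical numbers (10 batches per epoch, accumulation over 4 micro-batches); num_batches is a placeholder name, not one from the repository:

import math

num_batches = 10                   # hypothetical dataloader length
gradient_accumulation_steps = 4    # hypothetical accumulation factor

# Same math.ceil as in the hunk above: one optimizer step per
# gradient_accumulation_steps micro-batches, rounded up.
num_training_steps_per_epoch = math.ceil(num_batches / gradient_accumulation_steps)
print(num_training_steps_per_epoch)  # -> 3
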
@@ -434,44 +436,45 @@ def train_loop(
 
         with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(model):
-                    loss, acc, bsz = loss_step(step, batch)
+                loss, acc, bsz = loss_step(step, batch)
+                loss /= gradient_accumulation_steps
 
-                    accelerator.backward(loss)
+                avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
 
+                accelerator.backward(loss)
+
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
+                logs.update(on_log())
+
+                local_progress_bar.set_postfix(**logs)
+
+                train_step += 1
+
+                if train_step % gradient_accumulation_steps == 0:
                     on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
 
-                    avg_loss.update(loss.detach_(), bsz)
-                    avg_acc.update(acc.detach_(), bsz)
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
                     on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                    global_step += 1
+                    accelerator.log(logs, step=global_step)
 
-                    logs = {
-                        "train/loss": avg_loss.avg.item(),
-                        "train/acc": avg_acc.avg.item(),
-                        "train/cur_loss": loss.item(),
-                        "train/cur_acc": acc.item(),
-                        "lr": lr_scheduler.get_last_lr()[0],
-                    }
-                    logs.update(on_log())
-
-                    accelerator.log(logs, step=global_step)
-
-                    local_progress_bar.set_postfix(**logs)
+                    global_step += 1
 
                 if global_step >= num_training_steps:
                     break
 
         accelerator.wait_for_everyone()
 
@@ -571,6 +574,7 @@ def train(
     strategy: TrainingStrategy,
     no_val: bool = False,
     num_train_epochs: int = 100,
+    gradient_accumulation_steps: int = 1,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
@@ -631,6 +635,7 @@ def train(
         milestone_checkpoints=milestone_checkpoints,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
+        gradient_accumulation_steps=gradient_accumulation_steps,
         callbacks=callbacks,
     )
 
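
The train_loop changes above replace accelerate's accumulate() context manager with manual gradient accumulation: the per-batch loss is divided by gradient_accumulation_steps, gradients accumulate across backward() calls, and the optimizer only steps every gradient_accumulation_steps batches. A minimal, self-contained sketch of that pattern follows; the toy model, data, and hyperparameters are illustrative only and not taken from this repository.

import torch
from torch import nn

# Toy stand-ins; not code from training/functional.py.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
gradient_accumulation_steps = 4

train_step = 0
global_step = 0
for _ in range(16):                              # 16 micro-batches -> 4 optimizer steps
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss = loss / gradient_accumulation_steps    # keep gradient scale comparable to one large batch
    loss.backward()                              # gradients accumulate in .grad

    train_step += 1
    if train_step % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        global_step += 1

This mirrors the loop structure in the hunk above, where lr_scheduler.step(), the progress bars, and accelerator.log() also run inside the same modulo check.
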
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index d697554..fcf5c0d 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -41,12 +41,6 @@ def dreambooth_strategy_callbacks(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
-    if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
-
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
 