Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  19
1 file changed, 9 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py
index 2d81eca..c100ea2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -17,7 +17,6 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
-from trainer.base import Checkpointer
 
 
 def const(result=None):
@@ -205,7 +204,6 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     model: torch.nn.Module,
-    checkpointer: Checkpointer,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -217,7 +215,9 @@ def train_loop(
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
     on_before_optimize: Callable[[int], None] = const(),
     on_after_optimize: Callable[[float], None] = const(),
-    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
@@ -253,10 +253,10 @@ def train_loop(
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset)
+                    on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step + global_step_offset, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -347,19 +347,18 @@ def train_loop(
                 if avg_acc_val.avg.item() > max_acc_val:
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
                     max_acc_val = avg_acc_val.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset)
+            on_checkpoint(global_step + global_step_offset, "end")
+            on_sample(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step + global_step_offset, "end")
             accelerator.end_training()
-        quit()
