diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 19 |
1 file changed, 9 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py index 2d81eca..c100ea2 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -17,7 +17,6 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe | |||
17 | from models.clip.util import get_extended_embeddings | 17 | from models.clip.util import get_extended_embeddings |
18 | from models.clip.tokenizer import MultiCLIPTokenizer | 18 | from models.clip.tokenizer import MultiCLIPTokenizer |
19 | from training.util import AverageMeter | 19 | from training.util import AverageMeter |
20 | from trainer.base import Checkpointer | ||
21 | 20 | ||
22 | 21 | ||
23 | def const(result=None): | 22 | def const(result=None): |
@@ -205,7 +204,6 @@ def train_loop( | |||
205 | optimizer: torch.optim.Optimizer, | 204 | optimizer: torch.optim.Optimizer, |
206 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 205 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
207 | model: torch.nn.Module, | 206 | model: torch.nn.Module, |
208 | checkpointer: Checkpointer, | ||
209 | train_dataloader: DataLoader, | 207 | train_dataloader: DataLoader, |
210 | val_dataloader: DataLoader, | 208 | val_dataloader: DataLoader, |
211 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 209 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
@@ -217,7 +215,9 @@ def train_loop( | |||
217 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), | 215 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), |
218 | on_before_optimize: Callable[[int], None] = const(), | 216 | on_before_optimize: Callable[[int], None] = const(), |
219 | on_after_optimize: Callable[[float], None] = const(), | 217 | on_after_optimize: Callable[[float], None] = const(), |
220 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 218 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()), |
219 | on_sample: Callable[[int], None] = const(), | ||
220 | on_checkpoint: Callable[[int, str], None] = const(), | ||
221 | ): | 221 | ): |
222 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) | 222 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) |
223 | num_val_steps_per_epoch = len(val_dataloader) | 223 | num_val_steps_per_epoch = len(val_dataloader) |
@@ -253,10 +253,10 @@ def train_loop( | |||
253 | for epoch in range(num_epochs): | 253 | for epoch in range(num_epochs): |
254 | if accelerator.is_main_process: | 254 | if accelerator.is_main_process: |
255 | if epoch % sample_frequency == 0: | 255 | if epoch % sample_frequency == 0: |
256 | checkpointer.save_samples(global_step + global_step_offset) | 256 | on_sample(global_step + global_step_offset) |
257 | 257 | ||
258 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 258 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
259 | checkpointer.checkpoint(global_step + global_step_offset, "training") | 259 | on_checkpoint(global_step + global_step_offset, "training") |
260 | 260 | ||
261 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 261 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
262 | local_progress_bar.reset() | 262 | local_progress_bar.reset() |
@@ -347,19 +347,18 @@ def train_loop( | |||
347 | if avg_acc_val.avg.item() > max_acc_val: | 347 | if avg_acc_val.avg.item() > max_acc_val: |
348 | accelerator.print( | 348 | accelerator.print( |
349 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 349 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
350 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | 350 | on_checkpoint(global_step + global_step_offset, "milestone") |
351 | max_acc_val = avg_acc_val.avg.item() | 351 | max_acc_val = avg_acc_val.avg.item() |
352 | 352 | ||
353 | # Create the pipeline using the trained modules and save it. | 353 | # Create the pipeline using the trained modules and save it. |
354 | if accelerator.is_main_process: | 354 | if accelerator.is_main_process: |
355 | print("Finished!") | 355 | print("Finished!") |
356 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 356 | on_checkpoint(global_step + global_step_offset, "end") |
357 | checkpointer.save_samples(global_step + global_step_offset) | 357 | on_sample(global_step + global_step_offset) |
358 | accelerator.end_training() | 358 | accelerator.end_training() |
359 | 359 | ||
360 | except KeyboardInterrupt: | 360 | except KeyboardInterrupt: |
361 | if accelerator.is_main_process: | 361 | if accelerator.is_main_process: |
362 | print("Interrupted") | 362 | print("Interrupted") |
363 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 363 | on_checkpoint(global_step + global_step_offset, "end") |
364 | accelerator.end_training() | 364 | accelerator.end_training() |
365 | quit() | ||