diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/training/functional.py b/training/functional.py index 3d27380..7a3e821 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -39,11 +39,18 @@ class TrainingCallbacks(): | |||
| 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
| 40 | on_before_optimize: Callable[[int], None] = const() | 40 | on_before_optimize: Callable[[int], None] = const() |
| 41 | on_after_optimize: Callable[[float], None] = const() | 41 | on_after_optimize: Callable[[float], None] = const() |
| 42 | on_after_epoch: Callable[[float], None] = const() | ||
| 42 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 43 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
| 43 | on_sample: Callable[[int], None] = const() | 44 | on_sample: Callable[[int], None] = const() |
| 44 | on_checkpoint: Callable[[int, str], None] = const() | 45 | on_checkpoint: Callable[[int, str], None] = const() |
| 45 | 46 | ||
| 46 | 47 | ||
| 48 | @dataclass | ||
| 49 | class TrainingStrategy(): | ||
| 50 | callbacks: Callable[..., TrainingCallbacks] | ||
| 51 | prepare_unet: bool = False | ||
| 52 | |||
| 53 | |||
| 47 | def make_grid(images, rows, cols): | 54 | def make_grid(images, rows, cols): |
| 48 | w, h = images[0].size | 55 | w, h = images[0].size |
| 49 | grid = Image.new('RGB', size=(cols*w, rows*h)) | 56 | grid = Image.new('RGB', size=(cols*w, rows*h)) |
| @@ -373,6 +380,7 @@ def train_loop( | |||
| 373 | on_train = callbacks.on_train | 380 | on_train = callbacks.on_train |
| 374 | on_before_optimize = callbacks.on_before_optimize | 381 | on_before_optimize = callbacks.on_before_optimize |
| 375 | on_after_optimize = callbacks.on_after_optimize | 382 | on_after_optimize = callbacks.on_after_optimize |
| 383 | on_after_epoch = callbacks.on_after_epoch | ||
| 376 | on_eval = callbacks.on_eval | 384 | on_eval = callbacks.on_eval |
| 377 | on_sample = callbacks.on_sample | 385 | on_sample = callbacks.on_sample |
| 378 | on_checkpoint = callbacks.on_checkpoint | 386 | on_checkpoint = callbacks.on_checkpoint |
| @@ -434,6 +442,8 @@ def train_loop( | |||
| 434 | 442 | ||
| 435 | accelerator.wait_for_everyone() | 443 | accelerator.wait_for_everyone() |
| 436 | 444 | ||
| 445 | on_after_epoch(lr_scheduler.get_last_lr()[0]) | ||
| 446 | |||
| 437 | if val_dataloader is not None: | 447 | if val_dataloader is not None: |
| 438 | model.eval() | 448 | model.eval() |
| 439 | 449 | ||
| @@ -512,8 +522,7 @@ def train( | |||
| 512 | val_dataloader: Optional[DataLoader], | 522 | val_dataloader: Optional[DataLoader], |
| 513 | optimizer: torch.optim.Optimizer, | 523 | optimizer: torch.optim.Optimizer, |
| 514 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 524 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 515 | callbacks_fn: Callable[..., TrainingCallbacks], | 525 | strategy: TrainingStrategy, |
| 516 | prepare_unet: bool = False, | ||
| 517 | num_train_epochs: int = 100, | 526 | num_train_epochs: int = 100, |
| 518 | sample_frequency: int = 20, | 527 | sample_frequency: int = 20, |
| 519 | checkpoint_frequency: int = 50, | 528 | checkpoint_frequency: int = 50, |
| @@ -524,12 +533,12 @@ def train( | |||
| 524 | ): | 533 | ): |
| 525 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] | 534 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] |
| 526 | 535 | ||
| 527 | if prepare_unet: | 536 | if strategy.prepare_unet: |
| 528 | prep.append(unet) | 537 | prep.append(unet) |
| 529 | 538 | ||
| 530 | prep = accelerator.prepare(*prep) | 539 | prep = accelerator.prepare(*prep) |
| 531 | 540 | ||
| 532 | if prepare_unet: | 541 | if strategy.prepare_unet: |
| 533 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep | 542 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep |
| 534 | else: | 543 | else: |
| 535 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep | 544 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep |
| @@ -542,7 +551,7 @@ def train( | |||
| 542 | model.requires_grad_(False) | 551 | model.requires_grad_(False) |
| 543 | model.eval() | 552 | model.eval() |
| 544 | 553 | ||
| 545 | callbacks = callbacks_fn( | 554 | callbacks = strategy.callbacks( |
| 546 | accelerator=accelerator, | 555 | accelerator=accelerator, |
| 547 | unet=unet, | 556 | unet=unet, |
| 548 | text_encoder=text_encoder, | 557 | text_encoder=text_encoder, |
