From 5821523a524190490a287c5e2aacb6e72cc3a4cf Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 17 Jan 2023 07:20:45 +0100 Subject: Update --- training/functional.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index 3d27380..7a3e821 100644 --- a/training/functional.py +++ b/training/functional.py @@ -39,11 +39,18 @@ class TrainingCallbacks(): on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) on_before_optimize: Callable[[int], None] = const() on_after_optimize: Callable[[float], None] = const() + on_after_epoch: Callable[[float], None] = const() on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) on_sample: Callable[[int], None] = const() on_checkpoint: Callable[[int, str], None] = const() +@dataclass +class TrainingStrategy(): + callbacks: Callable[..., TrainingCallbacks] + prepare_unet: bool = False + + def make_grid(images, rows, cols): w, h = images[0].size grid = Image.new('RGB', size=(cols*w, rows*h)) @@ -373,6 +380,7 @@ def train_loop( on_train = callbacks.on_train on_before_optimize = callbacks.on_before_optimize on_after_optimize = callbacks.on_after_optimize + on_after_epoch = callbacks.on_after_epoch on_eval = callbacks.on_eval on_sample = callbacks.on_sample on_checkpoint = callbacks.on_checkpoint @@ -434,6 +442,8 @@ def train_loop( accelerator.wait_for_everyone() + on_after_epoch(lr_scheduler.get_last_lr()[0]) + if val_dataloader is not None: model.eval() @@ -512,8 +522,7 @@ def train( val_dataloader: Optional[DataLoader], optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, - callbacks_fn: Callable[..., TrainingCallbacks], - prepare_unet: bool = False, + strategy: TrainingStrategy, num_train_epochs: int = 100, sample_frequency: int = 20, checkpoint_frequency: int = 50, @@ -524,12 +533,12 @@ def train( ): prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] - if prepare_unet: + if strategy.prepare_unet: prep.append(unet) prep = accelerator.prepare(*prep) - if prepare_unet: + if strategy.prepare_unet: text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep else: text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep @@ -542,7 +551,7 @@ def train( model.requires_grad_(False) model.eval() - callbacks = callbacks_fn( + callbacks = strategy.callbacks( accelerator=accelerator, unet=unet, text_encoder=text_encoder, -- cgit v1.2.3-54-g00ecf