summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py19
1 files changed, 14 insertions, 5 deletions
diff --git a/training/functional.py b/training/functional.py
index 3d27380..7a3e821 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -39,11 +39,18 @@ class TrainingCallbacks():
39 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 39 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
40 on_before_optimize: Callable[[int], None] = const() 40 on_before_optimize: Callable[[int], None] = const()
41 on_after_optimize: Callable[[float], None] = const() 41 on_after_optimize: Callable[[float], None] = const()
42 on_after_epoch: Callable[[float], None] = const()
42 on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) 43 on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
43 on_sample: Callable[[int], None] = const() 44 on_sample: Callable[[int], None] = const()
44 on_checkpoint: Callable[[int, str], None] = const() 45 on_checkpoint: Callable[[int, str], None] = const()
45 46
46 47
48@dataclass
49class TrainingStrategy():
50 callbacks: Callable[..., TrainingCallbacks]
51 prepare_unet: bool = False
52
53
47def make_grid(images, rows, cols): 54def make_grid(images, rows, cols):
48 w, h = images[0].size 55 w, h = images[0].size
49 grid = Image.new('RGB', size=(cols*w, rows*h)) 56 grid = Image.new('RGB', size=(cols*w, rows*h))
@@ -373,6 +380,7 @@ def train_loop(
373 on_train = callbacks.on_train 380 on_train = callbacks.on_train
374 on_before_optimize = callbacks.on_before_optimize 381 on_before_optimize = callbacks.on_before_optimize
375 on_after_optimize = callbacks.on_after_optimize 382 on_after_optimize = callbacks.on_after_optimize
383 on_after_epoch = callbacks.on_after_epoch
376 on_eval = callbacks.on_eval 384 on_eval = callbacks.on_eval
377 on_sample = callbacks.on_sample 385 on_sample = callbacks.on_sample
378 on_checkpoint = callbacks.on_checkpoint 386 on_checkpoint = callbacks.on_checkpoint
@@ -434,6 +442,8 @@ def train_loop(
434 442
435 accelerator.wait_for_everyone() 443 accelerator.wait_for_everyone()
436 444
445 on_after_epoch(lr_scheduler.get_last_lr()[0])
446
437 if val_dataloader is not None: 447 if val_dataloader is not None:
438 model.eval() 448 model.eval()
439 449
@@ -512,8 +522,7 @@ def train(
512 val_dataloader: Optional[DataLoader], 522 val_dataloader: Optional[DataLoader],
513 optimizer: torch.optim.Optimizer, 523 optimizer: torch.optim.Optimizer,
514 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 524 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
515 callbacks_fn: Callable[..., TrainingCallbacks], 525 strategy: TrainingStrategy,
516 prepare_unet: bool = False,
517 num_train_epochs: int = 100, 526 num_train_epochs: int = 100,
518 sample_frequency: int = 20, 527 sample_frequency: int = 20,
519 checkpoint_frequency: int = 50, 528 checkpoint_frequency: int = 50,
@@ -524,12 +533,12 @@ def train(
524): 533):
525 prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] 534 prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
526 535
527 if prepare_unet: 536 if strategy.prepare_unet:
528 prep.append(unet) 537 prep.append(unet)
529 538
530 prep = accelerator.prepare(*prep) 539 prep = accelerator.prepare(*prep)
531 540
532 if prepare_unet: 541 if strategy.prepare_unet:
533 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep 542 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
534 else: 543 else:
535 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep 544 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
@@ -542,7 +551,7 @@ def train(
542 model.requires_grad_(False) 551 model.requires_grad_(False)
543 model.eval() 552 model.eval()
544 553
545 callbacks = callbacks_fn( 554 callbacks = strategy.callbacks(
546 accelerator=accelerator, 555 accelerator=accelerator,
547 unet=unet, 556 unet=unet,
548 text_encoder=text_encoder, 557 text_encoder=text_encoder,