Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  75
1 file changed, 74 insertions(+), 1 deletion(-)
diff --git a/training/functional.py b/training/functional.py
index c5b514a..1f2ca6d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,6 +1,7 @@
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Optional
+from functools import partial
 
 import torch
 import torch.nn.functional as F
@@ -376,3 +377,75 @@ def train_loop(
376 print("Interrupted") 377 print("Interrupted")
377 on_checkpoint(global_step + global_step_offset, "end") 378 on_checkpoint(global_step + global_step_offset, "end")
378 accelerator.end_training() 379 accelerator.end_training()
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    prior_loss_weight: float = 0,
+    on_prepare: Callable[[], dict[str, Any]] = const({}),
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        prior_loss_weight,
+        seed,
+    )
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=text_encoder,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval,
+        on_sample=on_sample,
+        on_checkpoint=on_checkpoint,
+    )
+
+    accelerator.free_memory()
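
For context, a minimal sketch of how the new train() entry point might be invoked from a training script. This is an assumption-laden illustration, not code from the commit: the accelerator, models, dataloaders, optimizer, and scheduler are assumed to have been constructed elsewhere, and the callback names (save_checkpoint, generate_samples) are hypothetical placeholders for the on_checkpoint/on_sample hooks. loss_step, train_loop, and const are the existing helpers in training/functional.py that the diff builds on.

# Hypothetical caller; every object below is assumed to be set up beforehand.
train(
    accelerator=accelerator,
    unet=unet,
    text_encoder=text_encoder,
    vae=vae,
    noise_scheduler=noise_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    dtype=torch.float16,          # weight dtype for the frozen VAE
    seed=42,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    num_train_epochs=100,
    sample_frequency=20,
    checkpoint_frequency=50,
    on_checkpoint=save_checkpoint,   # hypothetical callback: (step, reason) -> None
    on_sample=generate_samples,      # hypothetical callback: (step) -> None
)

Judging from its use both with and without an argument in the default callbacks, const appears to build a callable that ignores its arguments and always returns the given value (or None), but its exact definition is not part of this diff.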