summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index e7cc20f..68ea40c 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -672,6 +672,7 @@ def train(
672 optimizer: torch.optim.Optimizer, 672 optimizer: torch.optim.Optimizer,
673 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 673 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
674 strategy: TrainingStrategy, 674 strategy: TrainingStrategy,
675 compile_unet: bool = False,
675 no_val: bool = False, 676 no_val: bool = False,
676 num_train_epochs: int = 100, 677 num_train_epochs: int = 100,
677 gradient_accumulation_steps: int = 1, 678 gradient_accumulation_steps: int = 1,
@@ -699,9 +700,8 @@ def train(
699 vae.requires_grad_(False) 700 vae.requires_grad_(False)
700 vae.eval() 701 vae.eval()
701 702
702 unet = torch.compile(unet) 703 if compile_unet:
703 text_encoder = torch.compile(text_encoder) 704 unet = torch.compile(unet, backend='hidet')
704 vae = torch.compile(vae)
705 705
706 callbacks = strategy.callbacks( 706 callbacks = strategy.callbacks(
707 accelerator=accelerator, 707 accelerator=accelerator,