summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py
index 68ea40c..38dd59f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -700,8 +700,11 @@ def train(
700 vae.requires_grad_(False) 700 vae.requires_grad_(False)
701 vae.eval() 701 vae.eval()
702 702
703 vae = torch.compile(vae, backend='hidet')
704
703 if compile_unet: 705 if compile_unet:
704 unet = torch.compile(unet, backend='hidet') 706 unet = torch.compile(unet, backend='hidet')
707 # unet = torch.compile(unet)
705 708
706 callbacks = strategy.callbacks( 709 callbacks = strategy.callbacks(
707 accelerator=accelerator, 710 accelerator=accelerator,