diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py index 68ea40c..38dd59f 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -700,8 +700,11 @@ def train( | |||
700 | vae.requires_grad_(False) | 700 | vae.requires_grad_(False) |
701 | vae.eval() | 701 | vae.eval() |
702 | 702 | ||
703 | vae = torch.compile(vae, backend='hidet') | ||
704 | |||
703 | if compile_unet: | 705 | if compile_unet: |
704 | unet = torch.compile(unet, backend='hidet') | 706 | unet = torch.compile(unet, backend='hidet') |
707 | # unet = torch.compile(unet) | ||
705 | 708 | ||
706 | callbacks = strategy.callbacks( | 709 | callbacks = strategy.callbacks( |
707 | accelerator=accelerator, | 710 | accelerator=accelerator, |