diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py index 68ea40c..38dd59f 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -700,8 +700,11 @@ def train( | |||
| 700 | vae.requires_grad_(False) | 700 | vae.requires_grad_(False) |
| 701 | vae.eval() | 701 | vae.eval() |
| 702 | 702 | ||
| 703 | vae = torch.compile(vae, backend='hidet') | ||
| 704 | |||
| 703 | if compile_unet: | 705 | if compile_unet: |
| 704 | unet = torch.compile(unet, backend='hidet') | 706 | unet = torch.compile(unet, backend='hidet') |
| 707 | # unet = torch.compile(unet) | ||
| 705 | 708 | ||
| 706 | callbacks = strategy.callbacks( | 709 | callbacks = strategy.callbacks( |
| 707 | accelerator=accelerator, | 710 | accelerator=accelerator, |
