diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py index e7cc20f..68ea40c 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -672,6 +672,7 @@ def train( | |||
672 | optimizer: torch.optim.Optimizer, | 672 | optimizer: torch.optim.Optimizer, |
673 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 673 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
674 | strategy: TrainingStrategy, | 674 | strategy: TrainingStrategy, |
675 | compile_unet: bool = False, | ||
675 | no_val: bool = False, | 676 | no_val: bool = False, |
676 | num_train_epochs: int = 100, | 677 | num_train_epochs: int = 100, |
677 | gradient_accumulation_steps: int = 1, | 678 | gradient_accumulation_steps: int = 1, |
@@ -699,9 +700,8 @@ def train( | |||
699 | vae.requires_grad_(False) | 700 | vae.requires_grad_(False) |
700 | vae.eval() | 701 | vae.eval() |
701 | 702 | ||
702 | unet = torch.compile(unet) | 703 | if compile_unet: |
703 | text_encoder = torch.compile(text_encoder) | 704 | unet = torch.compile(unet, backend='hidet') |
704 | vae = torch.compile(vae) | ||
705 | 705 | ||
706 | callbacks = strategy.callbacks( | 706 | callbacks = strategy.callbacks( |
707 | accelerator=accelerator, | 707 | accelerator=accelerator, |