From 449b828349dc0d907199577c2b550780ad84e5b2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 29 Apr 2023 10:00:50 +0200 Subject: Fixed model compilation --- training/functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index e7cc20f..68ea40c 100644 --- a/training/functional.py +++ b/training/functional.py @@ -672,6 +672,7 @@ def train( optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, strategy: TrainingStrategy, + compile_unet: bool = False, no_val: bool = False, num_train_epochs: int = 100, gradient_accumulation_steps: int = 1, @@ -699,9 +700,8 @@ def train( vae.requires_grad_(False) vae.eval() - unet = torch.compile(unet) - text_encoder = torch.compile(text_encoder) - vae = torch.compile(vae) + if compile_unet: + unet = torch.compile(unet, backend='hidet') callbacks = strategy.callbacks( accelerator=accelerator, -- cgit v1.2.3-70-g09d2