From 449b828349dc0d907199577c2b550780ad84e5b2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 29 Apr 2023 10:00:50 +0200
Subject: Fixed model compilation

---
 train_lora.py          | 6 ++++++
 training/functional.py | 6 +++---
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 64346bc..74afeed 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -455,6 +455,11 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--compile_unet",
+        action="store_true",
+        help="Compile UNet with Torch Dynamo.",
+    )
     parser.add_argument(
         "--lora_rank",
         type=int,
@@ -892,6 +897,7 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
+        compile_unet=args.compile_unet,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         sample_scheduler=sample_scheduler,
diff --git a/training/functional.py b/training/functional.py
index e7cc20f..68ea40c 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -672,6 +672,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     strategy: TrainingStrategy,
+    compile_unet: bool = False,
     no_val: bool = False,
     num_train_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -699,9 +700,8 @@ def train(
     vae.requires_grad_(False)
     vae.eval()
 
-    unet = torch.compile(unet)
-    text_encoder = torch.compile(text_encoder)
-    vae = torch.compile(vae)
+    if compile_unet:
+        unet = torch.compile(unet, backend='hidet')
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
--
cgit v1.2.3-70-g09d2
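
The opt-in pattern this patch introduces (compile only the UNet, and only when requested) can be sketched in isolation. The snippet below is illustrative, not part of the repository: it assumes PyTorch >= 2.0 and the optional `hidet` package (`pip install hidet`), which registers the 'hidet' Dynamo backend; `maybe_compile` and the toy module are hypothetical names.

    # Minimal sketch of flag-guarded model compilation, assuming PyTorch >= 2.0
    # and the optional `hidet` backend package. Names here are illustrative.
    import torch
    import torch.nn as nn

    def maybe_compile(model: nn.Module, compile_model: bool = False) -> nn.Module:
        # Mirror the patch: leave the model untouched unless the flag is set,
        # so compilation stays an explicit opt-in rather than the default path.
        if compile_model:
            model = torch.compile(model, backend='hidet')
        return model

    model = maybe_compile(nn.Linear(16, 16), compile_model=True)
    # First call triggers backend compilation; subsequent calls reuse the result.
    print(model(torch.randn(2, 16)).shape)

Guarding the call also sidesteps the original bug this commit fixes: unconditionally compiling the text encoder and VAE alongside the UNet, where only the UNet benefits during training.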