author     Volpeon <git@volpeon.ink>  2023-04-29 10:00:50 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-29 10:00:50 +0200
commit     449b828349dc0d907199577c2b550780ad84e5b2 (patch)
tree       224257f78cfa8a0c8f8d51af15642bb1b7f7b5be
parent     Support torch.compile (diff)
download   textual-inversion-diff-449b828349dc0d907199577c2b550780ad84e5b2.tar.gz
           textual-inversion-diff-449b828349dc0d907199577c2b550780ad84e5b2.tar.bz2
           textual-inversion-diff-449b828349dc0d907199577c2b550780ad84e5b2.zip
Fixed model compilation
-rw-r--r--  train_lora.py          | 6 ++++++
-rw-r--r--  training/functional.py | 6 +++---
2 files changed, 9 insertions(+), 3 deletions(-)
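
With this change, torch.compile becomes opt-in instead of running unconditionally. A hedged usage sketch (the flag name comes from the diff below; all other training arguments are elided):

    # Opt in to UNet compilation; omitting the flag skips torch.compile entirely.
    python train_lora.py --compile_unet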
diff --git a/train_lora.py b/train_lora.py
index 64346bc..74afeed 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -456,6 +456,11 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--compile_unet",
+        action="store_true",
+        help="Compile UNet with Torch Dynamo.",
+    )
+    parser.add_argument(
         "--lora_rank",
         type=int,
         default=256,
@@ -892,6 +897,7 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
+        compile_unet=args.compile_unet,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         sample_scheduler=sample_scheduler,
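
Because --compile_unet uses action="store_true", it defaults to False, so existing invocations keep the uncompiled behavior. A minimal self-contained sketch of that pattern, reusing the argument definition from the diff (the reduced parser around it is illustrative only):

    import argparse

    # Reduced parser: the real parse_args() defines many more arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--compile_unet",
        action="store_true",  # absent -> False, present -> True
        help="Compile UNet with Torch Dynamo.",
    )

    print(parser.parse_args([]).compile_unet)                  # False
    print(parser.parse_args(["--compile_unet"]).compile_unet)  # True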
diff --git a/training/functional.py b/training/functional.py
index e7cc20f..68ea40c 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -672,6 +672,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     strategy: TrainingStrategy,
+    compile_unet: bool = False,
     no_val: bool = False,
     num_train_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -699,9 +700,8 @@ def train(
     vae.requires_grad_(False)
     vae.eval()
 
-    unet = torch.compile(unet)
-    text_encoder = torch.compile(text_encoder)
-    vae = torch.compile(vae)
+    if compile_unet:
+        unet = torch.compile(unet, backend='hidet')
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
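
The fix itself narrows compilation from three modules (UNet, text encoder, VAE) to the UNet alone, gates it behind the new parameter, and selects the hidet backend instead of the default inductor. A minimal sketch of that gating pattern, using a hypothetical helper name (maybe_compile is illustrative, not part of the repository; hidet is a separate package, and installing it registers the backend with torch.compile):

    import torch

    def maybe_compile(model: torch.nn.Module, enabled: bool) -> torch.nn.Module:
        """Wrap the model with torch.compile only when explicitly enabled."""
        if not enabled:
            return model  # eager mode: the new default behavior
        # backend='hidet' mirrors the diff; it requires `pip install hidet`.
        return torch.compile(model, backend="hidet")

    # unet = maybe_compile(unet, args.compile_unet)

Presumably compiling all three modules is what broke in the parent commit; restricting torch.compile to the UNet, the hot path of the diffusion loop, is the conservative fix the commit title refers to.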