diff options
| -rw-r--r-- | train_lora.py | 6 | ||||
| -rw-r--r-- | training/functional.py | 6 |
2 files changed, 9 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py index 64346bc..74afeed 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -456,6 +456,11 @@ def parse_args(): | |||
| 456 | ), | 456 | ), |
| 457 | ) | 457 | ) |
| 458 | parser.add_argument( | 458 | parser.add_argument( |
| 459 | "--compile_unet", | ||
| 460 | action="store_true", | ||
| 461 | help="Compile UNet with Torch Dynamo.", | ||
| 462 | ) | ||
| 463 | parser.add_argument( | ||
| 459 | "--lora_rank", | 464 | "--lora_rank", |
| 460 | type=int, | 465 | type=int, |
| 461 | default=256, | 466 | default=256, |
| @@ -892,6 +897,7 @@ def main(): | |||
| 892 | noise_scheduler=noise_scheduler, | 897 | noise_scheduler=noise_scheduler, |
| 893 | dtype=weight_dtype, | 898 | dtype=weight_dtype, |
| 894 | seed=args.seed, | 899 | seed=args.seed, |
| 900 | compile_unet=args.compile_unet, | ||
| 895 | guidance_scale=args.guidance_scale, | 901 | guidance_scale=args.guidance_scale, |
| 896 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 902 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 897 | sample_scheduler=sample_scheduler, | 903 | sample_scheduler=sample_scheduler, |
diff --git a/training/functional.py b/training/functional.py index e7cc20f..68ea40c 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -672,6 +672,7 @@ def train( | |||
| 672 | optimizer: torch.optim.Optimizer, | 672 | optimizer: torch.optim.Optimizer, |
| 673 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 673 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 674 | strategy: TrainingStrategy, | 674 | strategy: TrainingStrategy, |
| 675 | compile_unet: bool = False, | ||
| 675 | no_val: bool = False, | 676 | no_val: bool = False, |
| 676 | num_train_epochs: int = 100, | 677 | num_train_epochs: int = 100, |
| 677 | gradient_accumulation_steps: int = 1, | 678 | gradient_accumulation_steps: int = 1, |
| @@ -699,9 +700,8 @@ def train( | |||
| 699 | vae.requires_grad_(False) | 700 | vae.requires_grad_(False) |
| 700 | vae.eval() | 701 | vae.eval() |
| 701 | 702 | ||
| 702 | unet = torch.compile(unet) | 703 | if compile_unet: |
| 703 | text_encoder = torch.compile(text_encoder) | 704 | unet = torch.compile(unet, backend='hidet') |
| 704 | vae = torch.compile(vae) | ||
| 705 | 705 | ||
| 706 | callbacks = strategy.callbacks( | 706 | callbacks = strategy.callbacks( |
| 707 | accelerator=accelerator, | 707 | accelerator=accelerator, |
