diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py index 64346bc..74afeed 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -456,6 +456,11 @@ def parse_args(): | |||
456 | ), | 456 | ), |
457 | ) | 457 | ) |
458 | parser.add_argument( | 458 | parser.add_argument( |
459 | "--compile_unet", | ||
460 | action="store_true", | ||
461 | help="Compile UNet with Torch Dynamo.", | ||
462 | ) | ||
463 | parser.add_argument( | ||
459 | "--lora_rank", | 464 | "--lora_rank", |
460 | type=int, | 465 | type=int, |
461 | default=256, | 466 | default=256, |
@@ -892,6 +897,7 @@ def main(): | |||
892 | noise_scheduler=noise_scheduler, | 897 | noise_scheduler=noise_scheduler, |
893 | dtype=weight_dtype, | 898 | dtype=weight_dtype, |
894 | seed=args.seed, | 899 | seed=args.seed, |
900 | compile_unet=args.compile_unet, | ||
895 | guidance_scale=args.guidance_scale, | 901 | guidance_scale=args.guidance_scale, |
896 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 902 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
897 | sample_scheduler=sample_scheduler, | 903 | sample_scheduler=sample_scheduler, |