summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py
index 64346bc..74afeed 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -456,6 +456,11 @@ def parse_args():
456 ), 456 ),
457 ) 457 )
458 parser.add_argument( 458 parser.add_argument(
459 "--compile_unet",
460 action="store_true",
461 help="Compile UNet with Torch Dynamo.",
462 )
463 parser.add_argument(
459 "--lora_rank", 464 "--lora_rank",
460 type=int, 465 type=int,
461 default=256, 466 default=256,
@@ -892,6 +897,7 @@ def main():
892 noise_scheduler=noise_scheduler, 897 noise_scheduler=noise_scheduler,
893 dtype=weight_dtype, 898 dtype=weight_dtype,
894 seed=args.seed, 899 seed=args.seed,
900 compile_unet=args.compile_unet,
895 guidance_scale=args.guidance_scale, 901 guidance_scale=args.guidance_scale,
896 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, 902 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
897 sample_scheduler=sample_scheduler, 903 sample_scheduler=sample_scheduler,