diff options
| -rw-r--r-- | train_lora.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py index 74afeed..d95dbb9 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -44,7 +44,7 @@ torch.backends.cudnn.benchmark = True | |||
| 44 | torch._dynamo.config.log_level = logging.ERROR | 44 | torch._dynamo.config.log_level = logging.ERROR |
| 45 | 45 | ||
| 46 | hidet.torch.dynamo_config.use_tensor_core(True) | 46 | hidet.torch.dynamo_config.use_tensor_core(True) |
| 47 | hidet.torch.dynamo_config.search_space(2) | 47 | hidet.torch.dynamo_config.search_space(1) |
| 48 | 48 | ||
| 49 | 49 | ||
| 50 | def parse_args(): | 50 | def parse_args(): |
| @@ -461,6 +461,11 @@ def parse_args(): | |||
| 461 | help="Compile UNet with Torch Dynamo.", | 461 | help="Compile UNet with Torch Dynamo.", |
| 462 | ) | 462 | ) |
| 463 | parser.add_argument( | 463 | parser.add_argument( |
| 464 | "--use_xformers", | ||
| 465 | action="store_true", | ||
| 466 | help="Use xformers.", | ||
| 467 | ) | ||
| 468 | parser.add_argument( | ||
| 464 | "--lora_rank", | 469 | "--lora_rank", |
| 465 | type=int, | 470 | type=int, |
| 466 | default=256, | 471 | default=256, |
| @@ -715,8 +720,10 @@ def main(): | |||
| 715 | text_encoder = LoraModel(text_encoder_config, text_encoder) | 720 | text_encoder = LoraModel(text_encoder_config, text_encoder) |
| 716 | 721 | ||
| 717 | vae.enable_slicing() | 722 | vae.enable_slicing() |
| 718 | vae.set_use_memory_efficient_attention_xformers(True) | 723 | |
| 719 | unet.enable_xformers_memory_efficient_attention() | 724 | if args.use_xformers: |
| 725 | vae.set_use_memory_efficient_attention_xformers(True) | ||
| 726 | unet.enable_xformers_memory_efficient_attention() | ||
| 720 | 727 | ||
| 721 | if args.gradient_checkpointing: | 728 | if args.gradient_checkpointing: |
| 722 | unet.enable_gradient_checkpointing() | 729 | unet.enable_gradient_checkpointing() |
