summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-29 10:47:08 +0200
committerVolpeon <git@volpeon.ink>2023-04-29 10:47:08 +0200
commit5188d1e4c1d6aed3c510e5a534e39e40d8a3365d (patch)
tree9d05b8912aea05fad5c1bdc3170f38cd0ce1ed5e
parentFixed model compilation (diff)
downloadtextual-inversion-diff-5188d1e4c1d6aed3c510e5a534e39e40d8a3365d.tar.gz
textual-inversion-diff-5188d1e4c1d6aed3c510e5a534e39e40d8a3365d.tar.bz2
textual-inversion-diff-5188d1e4c1d6aed3c510e5a534e39e40d8a3365d.zip
Optional xformers
-rw-r--r--train_lora.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/train_lora.py b/train_lora.py
index 74afeed..d95dbb9 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -44,7 +44,7 @@ torch.backends.cudnn.benchmark = True
44torch._dynamo.config.log_level = logging.ERROR 44torch._dynamo.config.log_level = logging.ERROR
45 45
46hidet.torch.dynamo_config.use_tensor_core(True) 46hidet.torch.dynamo_config.use_tensor_core(True)
47hidet.torch.dynamo_config.search_space(2) 47hidet.torch.dynamo_config.search_space(1)
48 48
49 49
50def parse_args(): 50def parse_args():
@@ -461,6 +461,11 @@ def parse_args():
461 help="Compile UNet with Torch Dynamo.", 461 help="Compile UNet with Torch Dynamo.",
462 ) 462 )
463 parser.add_argument( 463 parser.add_argument(
464 "--use_xformers",
465 action="store_true",
466 help="Use xformers.",
467 )
468 parser.add_argument(
464 "--lora_rank", 469 "--lora_rank",
465 type=int, 470 type=int,
466 default=256, 471 default=256,
@@ -715,8 +720,10 @@ def main():
715 text_encoder = LoraModel(text_encoder_config, text_encoder) 720 text_encoder = LoraModel(text_encoder_config, text_encoder)
716 721
717 vae.enable_slicing() 722 vae.enable_slicing()
718 vae.set_use_memory_efficient_attention_xformers(True) 723
719 unet.enable_xformers_memory_efficient_attention() 724 if args.use_xformers:
725 vae.set_use_memory_efficient_attention_xformers(True)
726 unet.enable_xformers_memory_efficient_attention()
720 727
721 if args.gradient_checkpointing: 728 if args.gradient_checkpointing:
722 unet.enable_gradient_checkpointing() 729 unet.enable_gradient_checkpointing()