diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index d1defb3..7d10317 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -538,8 +538,10 @@ def main(): | |||
538 | tokenizer.set_dropout(args.vector_dropout) | 538 | tokenizer.set_dropout(args.vector_dropout) |
539 | 539 | ||
540 | vae.enable_slicing() | 540 | vae.enable_slicing() |
541 | vae.set_use_memory_efficient_attention_xformers(True) | 541 | # vae.set_use_memory_efficient_attention_xformers(True) |
542 | unet.enable_xformers_memory_efficient_attention() | 542 | # unet.enable_xformers_memory_efficient_attention() |
543 | |||
544 | # unet = torch.compile(unet) | ||
543 | 545 | ||
544 | if args.gradient_checkpointing: | 546 | if args.gradient_checkpointing: |
545 | unet.enable_gradient_checkpointing() | 547 | unet.enable_gradient_checkpointing() |