diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 7d10317..3aa1027 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -538,9 +538,8 @@ def main(): | |||
538 | tokenizer.set_dropout(args.vector_dropout) | 538 | tokenizer.set_dropout(args.vector_dropout) |
539 | 539 | ||
540 | vae.enable_slicing() | 540 | vae.enable_slicing() |
541 | # vae.set_use_memory_efficient_attention_xformers(True) | 541 | vae.set_use_memory_efficient_attention_xformers(True) |
542 | # unet.enable_xformers_memory_efficient_attention() | 542 | unet.enable_xformers_memory_efficient_attention() |
543 | |||
544 | # unet = torch.compile(unet) | 543 | # unet = torch.compile(unet) |
545 | 544 | ||
546 | if args.gradient_checkpointing: | 545 | if args.gradient_checkpointing: |