diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 7d10317..3aa1027 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -538,9 +538,8 @@ def main(): | |||
| 538 | tokenizer.set_dropout(args.vector_dropout) | 538 | tokenizer.set_dropout(args.vector_dropout) |
| 539 | 539 | ||
| 540 | vae.enable_slicing() | 540 | vae.enable_slicing() |
| 541 | # vae.set_use_memory_efficient_attention_xformers(True) | 541 | vae.set_use_memory_efficient_attention_xformers(True) |
| 542 | # unet.enable_xformers_memory_efficient_attention() | 542 | unet.enable_xformers_memory_efficient_attention() |
| 543 | |||
| 544 | # unet = torch.compile(unet) | 543 | # unet = torch.compile(unet) |
| 545 | 544 | ||
| 546 | if args.gradient_checkpointing: | 545 | if args.gradient_checkpointing: |
