diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 85b756c..5a7911c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -464,8 +464,8 @@ def main(): | |||
464 | tokenizer.set_dropout(args.vector_dropout) | 464 | tokenizer.set_dropout(args.vector_dropout) |
465 | 465 | ||
466 | vae.enable_slicing() | 466 | vae.enable_slicing() |
467 | # vae.set_use_memory_efficient_attention_xformers(True) | 467 | vae.set_use_memory_efficient_attention_xformers(True) |
468 | # unet.enable_xformers_memory_efficient_attention() | 468 | unet.enable_xformers_memory_efficient_attention() |
469 | 469 | ||
470 | if args.gradient_checkpointing: | 470 | if args.gradient_checkpointing: |
471 | unet.enable_gradient_checkpointing() | 471 | unet.enable_gradient_checkpointing() |