diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 85b756c..5a7911c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -464,8 +464,8 @@ def main(): | |||
| 464 | tokenizer.set_dropout(args.vector_dropout) | 464 | tokenizer.set_dropout(args.vector_dropout) |
| 465 | 465 | ||
| 466 | vae.enable_slicing() | 466 | vae.enable_slicing() |
| 467 | # vae.set_use_memory_efficient_attention_xformers(True) | 467 | vae.set_use_memory_efficient_attention_xformers(True) |
| 468 | # unet.enable_xformers_memory_efficient_attention() | 468 | unet.enable_xformers_memory_efficient_attention() |
| 469 | 469 | ||
| 470 | if args.gradient_checkpointing: | 470 | if args.gradient_checkpointing: |
| 471 | unet.enable_gradient_checkpointing() | 471 | unet.enable_gradient_checkpointing() |
