diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py index 8a06ae8..330bcd6 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -421,8 +421,8 @@ def main(): | |||
421 | args.pretrained_model_name_or_path) | 421 | args.pretrained_model_name_or_path) |
422 | 422 | ||
423 | vae.enable_slicing() | 423 | vae.enable_slicing() |
424 | # vae.set_use_memory_efficient_attention_xformers(True) | 424 | vae.set_use_memory_efficient_attention_xformers(True) |
425 | # unet.enable_xformers_memory_efficient_attention() | 425 | unet.enable_xformers_memory_efficient_attention() |
426 | 426 | ||
427 | unet.to(accelerator.device, dtype=weight_dtype) | 427 | unet.to(accelerator.device, dtype=weight_dtype) |
428 | text_encoder.to(accelerator.device, dtype=weight_dtype) | 428 | text_encoder.to(accelerator.device, dtype=weight_dtype) |