commit    757e6af0af2f8de6da696976f3110cd70085adad
tree      cf219fffa359440bc8f2a2d6dd4a647715d66893
parent    Add Lora
author    Volpeon <git@volpeon.ink>  2023-02-07 20:56:37 +0100
committer Volpeon <git@volpeon.ink>  2023-02-07 20:56:37 +0100
Fix Lora memory usage
Diffstat (limited to 'train_lora.py')
 train_lora.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 2cb85cc..b273ae1 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -430,7 +430,7 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
+        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
             hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
         )
 
