author    Volpeon <git@volpeon.ink>  2023-02-07 20:56:37 +0100
committer Volpeon <git@volpeon.ink>  2023-02-07 20:56:37 +0100
commit    757e6af0af2f8de6da696976f3110cd70085adad
tree      cf219fffa359440bc8f2a2d6dd4a647715d66893 /train_lora.py
parent    Add Lora
Fix Lora memory usage
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  |  4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 2cb85cc..b273ae1 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -430,7 +430,7 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
+        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
             hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
         )
 
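The change swaps LoRACrossAttnProcessor for its xFormers counterpart, so the LoRA-adapted attention layers run through xformers' memory-efficient kernels instead of materializing the full attention matrix; that blockwise computation is where the memory saving comes from.

For context, a minimal sketch of how such processors get wired into a UNet. It assumes diffusers ~0.12 (where these classes still live in diffusers.models.cross_attention) plus an installed xformers; the mid_block/up_blocks branches, the checkpoint name, and the set_attn_processor/AttnProcsLayers calls are filled in from the standard diffusers LoRA recipe rather than taken from this diff:

from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor

# Hypothetical checkpoint; any SD 1.x UNet is wired up the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no encoder states), attn2 cross-attends
    # to the text encoder, so only attn2 gets a cross_attention_dim.
    cross_attention_dim = (
        None if name.endswith("attn1.processor")
        else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    # The xFormers variant computes attention with
    # xformers.ops.memory_efficient_attention rather than a full softmax
    # matrix, which is what this commit switches to.
    lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(lora_attn_procs)
# AttnProcsLayers exposes just the LoRA weights as trainable parameters,
# so the optimizer never touches the frozen UNet.
lora_layers = AttnProcsLayers(unet.attn_processors)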