From 757e6af0af2f8de6da696976f3110cd70085adad Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 7 Feb 2023 20:56:37 +0100
Subject: Fix Lora memory usage

---
 train_lora.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index 2cb85cc..b273ae1 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -430,7 +430,7 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
+        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
         )
 
-- 
cgit v1.2.3-54-g00ecf
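
For context: the hunks above swap the plain LoRACrossAttnProcessor for the xFormers-backed LoRAXFormersCrossAttnProcessor, so the LoRA attention layers run through xFormers' memory-efficient attention kernels instead of the default implementation, which is where the memory saving comes from. Below is a minimal sketch of how that processor is typically wired into the UNet in diffusers (~0.12) LoRA training code. Only the down_blocks branch and the constructor call appear in the patch context; the surrounding loop, the pretrained model id, and the set_attn_processor / AttnProcsLayers calls are assumptions based on the standard diffusers LoRA example, not taken from this repository's train_lora.py.

# Illustrative sketch, not verbatim from train_lora.py.
# Assumes diffusers ~0.12 with xformers installed; the model id below is a
# placeholder, not something specified by this patch.
from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention (attn1) has no cross-attention dim; attn2 attends to the
    # text encoder hidden states.
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    # The xFormers-backed processor computes attention with memory-efficient
    # kernels, reducing peak VRAM during LoRA training.
    lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(lora_attn_procs)
# Wrap only the LoRA weights so the optimizer sees just the adapter parameters.
lora_layers = AttnProcsLayers(unet.attn_processors)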