From fe3113451fdde72ddccfc71639f0a2a1e146209a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 7 Mar 2023 07:11:51 +0100
Subject: Update

---
 train_lora.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 6e72376..e65e7be 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -291,6 +291,12 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=256,
+        help="LoRA rank.",
+    )
     parser.add_argument(
         "--sample_frequency",
         type=int,
@@ -420,10 +426,6 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
@@ -439,11 +441,18 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.lora_rank
         )
 
     unet.set_attn_processor(lora_attn_procs)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
     lora_layers = AttnProcsLayers(unet.attn_processors)
 
     if args.embeddings_dir is not None:
-- 
cgit v1.2.3-54-g00ecf
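
Note on the change: the hunks above drop the xformers-specific LoRAXFormersCrossAttnProcessor in favour of the plain LoRACrossAttnProcessor, expose the LoRA rank as a new --lora_rank argument (default 256), and enable the VAE slicing / xformers memory optimizations only after the LoRA attention processors have been attached. Below is a minimal, self-contained sketch of the same setup, assuming diffusers ~0.14 (where LoRACrossAttnProcessor still lives in diffusers.models.cross_attention); the model id, learning rate, and rank value are illustrative examples, not taken from this repository.

# Sketch: attach rank-configurable LoRA cross-attention processors to a UNet,
# mirroring the loop this patch modifies. Assumes diffusers ~0.14; the model id
# and hyperparameters below are hypothetical.
import torch
from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.utils import is_xformers_available

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # illustrative model id
)

lora_rank = 256  # matches the new --lora_rank default in the patch

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no cross-attention dim); attn2 attends to the
    # text-encoder hidden states.
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
        rank=lora_rank,
    )

# Replace the default processors first, then enable the memory optimization,
# in the same order as the patch (requires the xformers package).
unet.set_attn_processor(lora_attn_procs)
if is_xformers_available():
    unet.enable_xformers_memory_efficient_attention()

# Only the LoRA layers are trained; the base UNet weights stay frozen.
lora_layers = AttnProcsLayers(unet.attn_processors)
optimizer = torch.optim.AdamW(lora_layers.parameters(), lr=1e-4)

The rank controls the size of the low-rank update matrices: a higher rank (such as the 256 default added here) gives the adapter more capacity at the cost of more trainable parameters and memory.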