From 9ea20241bbeb2f32199067096272e13647c512eb Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 8 Feb 2023 07:27:55 +0100
Subject: Fixed Lora training

---
 train_lora.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index b273ae1..ab1753b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -177,6 +177,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--find_lr",
         action="store_true",
@@ -402,6 +407,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -418,6 +429,12 @@ def main():
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     lora_attn_procs = {}
     for name in unet.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
@@ -467,12 +484,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     trainer = partial(
         train,
         accelerator=accelerator,
-- 
cgit v1.2.3-54-g00ecf
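
For context, below is a minimal sketch of how the pieces this patch touches typically fit together with the diffusers 0.12-era cross_attention API: the weight dtype is chosen before the frozen models are moved to the device (the ordering this commit fixes), gradient checkpointing is enabled on request, and one LoRA attention processor is built per attention module. The helper name prepare_lora_unet and the use of LoRACrossAttnProcessor for every module are assumptions for illustration, not code from this repository; the script also imports LoRAXFormersCrossAttnProcessor, presumably selected when xformers attention is enabled.

import torch
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor


def prepare_lora_unet(args, accelerator, unet, text_encoder):
    # Hypothetical helper; args, accelerator, unet and text_encoder follow the
    # names used in train_lora.py.

    # Resolve the weight dtype first, so the frozen models are cast exactly once.
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # The frozen base models live on the accelerator device in the reduced
    # dtype; only the LoRA layers remain trainable.
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # One LoRA processor per attention module. Self-attention ("attn1") has no
    # cross-attention dimension; the hidden size depends on which UNet block
    # the module belongs to.
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )

    unet.set_attn_processor(lora_attn_procs)
    # AttnProcsLayers wraps the processors so the optimizer only sees LoRA weights.
    return AttnProcsLayers(unet.attn_processors)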