path: root/train_lora.py
author     Volpeon <git@volpeon.ink>   2023-02-08 07:27:55 +0100
committer  Volpeon <git@volpeon.ink>   2023-02-08 07:27:55 +0100
commit     9ea20241bbeb2f32199067096272e13647c512eb
tree       9e0891a74d0965da75e9d3f30628b69d5ba3deaf /train_lora.py
parent     Fix Lora memory usage
Fixed Lora training
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  |  25
1 file changed, 18 insertions(+), 7 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index b273ae1..ab1753b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -178,6 +178,11 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
         "--find_lr",
         action="store_true",
         help="Automatically find a learning rate (no training).",
@@ -402,6 +407,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -418,6 +429,12 @@ def main():
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     lora_attn_procs = {}
     for name in unet.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
@@ -467,12 +484,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     trainer = partial(
         train,
         accelerator=accelerator,
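
For orientation, below is a minimal, hypothetical sketch (not code from this repository) of how the pieces touched by this commit fit together, assuming the diffusers cross-attention LoRA API available in early 2023 (LoRACrossAttnProcessor, AttnProcsLayers, UNet2DConditionModel.set_attn_processor). The model id, device, mixed-precision value, and learning rate are placeholders.

# Hypothetical sketch only; placeholder model id and hyperparameters.
import torch
from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the frozen UNet (placeholder model id).
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Derive the dtype for the frozen weights from the mixed-precision setting,
# mirroring the block this commit moves ahead of the LoRA setup.
mixed_precision = "fp16"  # placeholder; would come from args.mixed_precision
weight_dtype = torch.float32
if mixed_precision == "fp16":
    weight_dtype = torch.float16
elif mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

unet.enable_gradient_checkpointing()   # what --gradient_checkpointing toggles
unet.to(device, dtype=weight_dtype)    # cast the frozen weights first

# Build one LoRA processor per attention module; self-attention ("attn1")
# has no cross-attention dimension.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(lora_attn_procs)

# Only the freshly created LoRA layers (still float32) are trained; an
# accelerate-based loop would typically prepare() these before training.
lora_layers = AttnProcsLayers(unet.attn_processors)
optimizer = torch.optim.AdamW(lora_layers.parameters(), lr=1e-4)

The ordering appears to be the point of the fix: weight_dtype is now computed right after the Accelerator is created, and the frozen unet and text encoder are cast to that dtype before the LoRA processors are attached, so the newly created LoRA layers can stay in float32 while the frozen weights run in half precision.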