| author | Volpeon <git@volpeon.ink> | 2023-02-08 07:27:55 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-08 07:27:55 +0100 |
| commit | 9ea20241bbeb2f32199067096272e13647c512eb (patch) | |
| tree | 9e0891a74d0965da75e9d3f30628b69d5ba3deaf /train_lora.py | |
| parent | Fix Lora memory usage (diff) | |
Fixed Lora training
Diffstat (limited to 'train_lora.py')

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_lora.py | 25 |

1 file changed, 18 insertions(+), 7 deletions(-)
```diff
diff --git a/train_lora.py b/train_lora.py
index b273ae1..ab1753b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -178,6 +178,11 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
         "--find_lr",
         action="store_true",
         help="Automatically find a learning rate (no training).",
@@ -402,6 +407,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -418,6 +429,12 @@ def main():
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     lora_attn_procs = {}
     for name in unet.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
@@ -467,12 +484,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     trainer = partial(
         train,
         accelerator=accelerator,
```
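For context, a minimal sketch of how the touched part of `main()` reads after this commit: the `weight_dtype` selection and gradient checkpointing now run before the UNet and text encoder are moved to the accelerator device, ahead of the LoRA attention-processor setup. Note that `args`, `accelerator`, `unet`, and `text_encoder` come from the surrounding script, the diff truncates the `lora_attn_procs` loop after the `cross_attention_dim` line, and everything below that point (the `hidden_size` lookup and the `LoRACrossAttnProcessor` construction) is an assumption based on the diffusers LoRA examples of this release; the script also imports `LoRAXFormersCrossAttnProcessor` and may pick that variant when xformers is enabled.

```python
import torch
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# Pick the weight dtype from the mixed-precision setting (moved up in this
# commit so it is available before the models are placed on the device).
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Optionally trade extra backward-pass compute for lower memory usage.
if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()

# The frozen base models are kept in the reduced dtype; only the LoRA layers train.
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

# Install a LoRA processor on every attention block of the UNet.
# NOTE: the block-size lookup and processor construction below are an
# assumption following the diffusers LoRA example, not part of this diff.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)  # trainable parameters handed to the optimizer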
