diff options
author | Volpeon <git@volpeon.ink> | 2023-02-08 07:27:55 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-08 07:27:55 +0100 |
commit | 9ea20241bbeb2f32199067096272e13647c512eb (patch) | |
tree | 9e0891a74d0965da75e9d3f30628b69d5ba3deaf | |
parent | Fix Lora memory usage (diff) | |
download | textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.tar.gz textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.tar.bz2 textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.zip |
Fixed Lora training
-rw-r--r-- | train_dreambooth.py | 12 | ||||
-rw-r--r-- | train_lora.py | 25 | ||||
-rw-r--r-- | train_ti.py | 12 | ||||
-rw-r--r-- | training/strategy/lora.py | 23 |
4 files changed, 35 insertions, 37 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 5a4c47b..a29c507 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -442,6 +442,12 @@ def main(): | |||
442 | mixed_precision=args.mixed_precision | 442 | mixed_precision=args.mixed_precision |
443 | ) | 443 | ) |
444 | 444 | ||
445 | weight_dtype = torch.float32 | ||
446 | if args.mixed_precision == "fp16": | ||
447 | weight_dtype = torch.float16 | ||
448 | elif args.mixed_precision == "bf16": | ||
449 | weight_dtype = torch.bfloat16 | ||
450 | |||
445 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 451 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
446 | 452 | ||
447 | if args.seed is None: | 453 | if args.seed is None: |
@@ -495,12 +501,6 @@ def main(): | |||
495 | else: | 501 | else: |
496 | optimizer_class = torch.optim.AdamW | 502 | optimizer_class = torch.optim.AdamW |
497 | 503 | ||
498 | weight_dtype = torch.float32 | ||
499 | if args.mixed_precision == "fp16": | ||
500 | weight_dtype = torch.float16 | ||
501 | elif args.mixed_precision == "bf16": | ||
502 | weight_dtype = torch.bfloat16 | ||
503 | |||
504 | trainer = partial( | 504 | trainer = partial( |
505 | train, | 505 | train, |
506 | accelerator=accelerator, | 506 | accelerator=accelerator, |
diff --git a/train_lora.py b/train_lora.py index b273ae1..ab1753b 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger | |||
13 | from accelerate.utils import LoggerType, set_seed | 13 | from accelerate.utils import LoggerType, set_seed |
14 | from slugify import slugify | 14 | from slugify import slugify |
15 | from diffusers.loaders import AttnProcsLayers | 15 | from diffusers.loaders import AttnProcsLayers |
16 | from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor | 16 | from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor |
17 | 17 | ||
18 | from util import load_config, load_embeddings_from_dir | 18 | from util import load_config, load_embeddings_from_dir |
19 | from data.csv import VlpnDataModule, keyword_filter | 19 | from data.csv import VlpnDataModule, keyword_filter |
@@ -178,6 +178,11 @@ def parse_args(): | |||
178 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 178 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
179 | ) | 179 | ) |
180 | parser.add_argument( | 180 | parser.add_argument( |
181 | "--gradient_checkpointing", | ||
182 | action="store_true", | ||
183 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | ||
184 | ) | ||
185 | parser.add_argument( | ||
181 | "--find_lr", | 186 | "--find_lr", |
182 | action="store_true", | 187 | action="store_true", |
183 | help="Automatically find a learning rate (no training).", | 188 | help="Automatically find a learning rate (no training).", |
@@ -402,6 +407,12 @@ def main(): | |||
402 | mixed_precision=args.mixed_precision | 407 | mixed_precision=args.mixed_precision |
403 | ) | 408 | ) |
404 | 409 | ||
410 | weight_dtype = torch.float32 | ||
411 | if args.mixed_precision == "fp16": | ||
412 | weight_dtype = torch.float16 | ||
413 | elif args.mixed_precision == "bf16": | ||
414 | weight_dtype = torch.bfloat16 | ||
415 | |||
405 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 416 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
406 | 417 | ||
407 | if args.seed is None: | 418 | if args.seed is None: |
@@ -418,6 +429,12 @@ def main(): | |||
418 | vae.set_use_memory_efficient_attention_xformers(True) | 429 | vae.set_use_memory_efficient_attention_xformers(True) |
419 | unet.enable_xformers_memory_efficient_attention() | 430 | unet.enable_xformers_memory_efficient_attention() |
420 | 431 | ||
432 | if args.gradient_checkpointing: | ||
433 | unet.enable_gradient_checkpointing() | ||
434 | |||
435 | unet.to(accelerator.device, dtype=weight_dtype) | ||
436 | text_encoder.to(accelerator.device, dtype=weight_dtype) | ||
437 | |||
421 | lora_attn_procs = {} | 438 | lora_attn_procs = {} |
422 | for name in unet.attn_processors.keys(): | 439 | for name in unet.attn_processors.keys(): |
423 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim | 440 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim |
@@ -467,12 +484,6 @@ def main(): | |||
467 | else: | 484 | else: |
468 | optimizer_class = torch.optim.AdamW | 485 | optimizer_class = torch.optim.AdamW |
469 | 486 | ||
470 | weight_dtype = torch.float32 | ||
471 | if args.mixed_precision == "fp16": | ||
472 | weight_dtype = torch.float16 | ||
473 | elif args.mixed_precision == "bf16": | ||
474 | weight_dtype = torch.bfloat16 | ||
475 | |||
476 | trainer = partial( | 487 | trainer = partial( |
477 | train, | 488 | train, |
478 | accelerator=accelerator, | 489 | accelerator=accelerator, |
diff --git a/train_ti.py b/train_ti.py index 56f9e97..2840def 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -513,6 +513,12 @@ def main(): | |||
513 | mixed_precision=args.mixed_precision | 513 | mixed_precision=args.mixed_precision |
514 | ) | 514 | ) |
515 | 515 | ||
516 | weight_dtype = torch.float32 | ||
517 | if args.mixed_precision == "fp16": | ||
518 | weight_dtype = torch.float16 | ||
519 | elif args.mixed_precision == "bf16": | ||
520 | weight_dtype = torch.bfloat16 | ||
521 | |||
516 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 522 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
517 | 523 | ||
518 | if args.seed is None: | 524 | if args.seed is None: |
@@ -564,12 +570,6 @@ def main(): | |||
564 | else: | 570 | else: |
565 | optimizer_class = torch.optim.AdamW | 571 | optimizer_class = torch.optim.AdamW |
566 | 572 | ||
567 | weight_dtype = torch.float32 | ||
568 | if args.mixed_precision == "fp16": | ||
569 | weight_dtype = torch.float16 | ||
570 | elif args.mixed_precision == "bf16": | ||
571 | weight_dtype = torch.bfloat16 | ||
572 | |||
573 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | 573 | checkpoint_output_dir = output_dir.joinpath("checkpoints") |
574 | 574 | ||
575 | trainer = partial( | 575 | trainer = partial( |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 92abaa6..bc10e58 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -89,20 +89,14 @@ def lora_strategy_callbacks( | |||
89 | @torch.no_grad() | 89 | @torch.no_grad() |
90 | def on_checkpoint(step, postfix): | 90 | def on_checkpoint(step, postfix): |
91 | print(f"Saving checkpoint for step {step}...") | 91 | print(f"Saving checkpoint for step {step}...") |
92 | orig_unet_dtype = unet.dtype | 92 | |
93 | unet.to(dtype=torch.float32) | 93 | unet_ = accelerator.unwrap_model(unet) |
94 | unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}")) | 94 | unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}")) |
95 | unet.to(dtype=orig_unet_dtype) | 95 | del unet_ |
96 | 96 | ||
97 | @torch.no_grad() | 97 | @torch.no_grad() |
98 | def on_sample(step): | 98 | def on_sample(step): |
99 | orig_unet_dtype = unet.dtype | ||
100 | unet.to(dtype=weight_dtype) | ||
101 | save_samples_(step=step) | 99 | save_samples_(step=step) |
102 | unet.to(dtype=orig_unet_dtype) | ||
103 | |||
104 | if torch.cuda.is_available(): | ||
105 | torch.cuda.empty_cache() | ||
106 | 100 | ||
107 | return TrainingCallbacks( | 101 | return TrainingCallbacks( |
108 | on_prepare=on_prepare, | 102 | on_prepare=on_prepare, |
@@ -126,16 +120,9 @@ def lora_prepare( | |||
126 | lora_layers: AttnProcsLayers, | 120 | lora_layers: AttnProcsLayers, |
127 | **kwargs | 121 | **kwargs |
128 | ): | 122 | ): |
129 | weight_dtype = torch.float32 | ||
130 | if accelerator.state.mixed_precision == "fp16": | ||
131 | weight_dtype = torch.float16 | ||
132 | elif accelerator.state.mixed_precision == "bf16": | ||
133 | weight_dtype = torch.bfloat16 | ||
134 | |||
135 | lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 123 | lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
136 | lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 124 | lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler) |
137 | unet.to(accelerator.device, dtype=weight_dtype) | 125 | |
138 | text_encoder.to(accelerator.device, dtype=weight_dtype) | ||
139 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers} | 126 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers} |
140 | 127 | ||
141 | 128 | ||