-rw-r--r--   train_dreambooth.py        | 12
-rw-r--r--   train_lora.py              | 25
-rw-r--r--   train_ti.py                | 12
-rw-r--r--   training/strategy/lora.py  | 23
4 files changed, 35 insertions(+), 37 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a4c47b..a29c507 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -442,6 +442,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -495,12 +501,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     trainer = partial(
         train,
         accelerator=accelerator,
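The same mixed-precision-to-dtype mapping is hoisted to right after the Accelerator is constructed in all three training scripts, so weight_dtype is available before any model is moved or cast. A minimal sketch of that shared pattern, assuming only a parsed args.mixed_precision value of "no", "fp16", or "bf16" (resolve_weight_dtype is an illustrative helper name, not part of the repository):

import torch

def resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
    # Default to full precision; narrow only for the two mixed-precision modes.
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

# e.g. weight_dtype = resolve_weight_dtype(args.mixed_precision)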
diff --git a/train_lora.py b/train_lora.py
index b273ae1..ab1753b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -178,6 +178,11 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
         "--find_lr",
         action="store_true",
         help="Automatically find a learning rate (no training).",
@@ -402,6 +407,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -418,6 +429,12 @@ def main():
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     lora_attn_procs = {}
     for name in unet.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
@@ -467,12 +484,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     trainer = partial(
         train,
         accelerator=accelerator,
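train_lora.py additionally gains a --gradient_checkpointing flag and now casts the frozen UNet and text encoder to the accelerator device in weight_dtype itself, rather than leaving that to the LoRA strategy's prepare step. A hedged sketch of the intended setup order (prepare_models is a hypothetical helper, not repository code; the trainable LoRA attention processors are created afterwards and handled separately by accelerator.prepare):

import torch

def prepare_models(unet, text_encoder, device, weight_dtype, gradient_checkpointing: bool):
    # diffusers' UNet2DConditionModel recomputes activations during the
    # backward pass when gradient checkpointing is on, trading speed for memory.
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()
    # Cast the frozen base models once, up front, so later callbacks never
    # need to shuffle them between dtypes.
    unet.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)
    return unet, text_encoder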
diff --git a/train_ti.py b/train_ti.py
index 56f9e97..2840def 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -513,6 +513,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -564,12 +570,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     checkpoint_output_dir = output_dir.joinpath("checkpoints")
 
     trainer = partial(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 92abaa6..bc10e58 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -89,20 +89,14 @@ def lora_strategy_callbacks(
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=torch.float32)
-        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
-        unet.to(dtype=orig_unet_dtype)
+
+        unet_ = accelerator.unwrap_model(unet)
+        unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=weight_dtype)
         save_samples_(step=step)
-        unet.to(dtype=orig_unet_dtype)
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
@@ -126,16 +120,9 @@ def lora_prepare(
     lora_layers: AttnProcsLayers,
     **kwargs
 ):
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
 
 
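With the base models already living in weight_dtype, the LoRA strategy no longer round-trips dtypes: on_checkpoint unwraps the model and saves the attention processors directly, and on_sample simply calls save_samples_. A small sketch of the save path, assuming an accelerate Accelerator and a diffusers UNet with LoRA attention processors attached (save_lora_checkpoint is an illustrative name, not the repository's):

from pathlib import Path

def save_lora_checkpoint(accelerator, unet, checkpoint_output_dir: Path, step: int, postfix: str):
    # unwrap_model strips any wrapper accelerator.prepare may have added;
    # it is a no-op if the UNet itself was never prepared.
    unet_ = accelerator.unwrap_model(unet)
    # save_attn_procs writes only the LoRA attention-processor weights,
    # not the full UNet state dict.
    unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
    del unet_  # drop the extra reference; training continues on the original model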
