From 9ea20241bbeb2f32199067096272e13647c512eb Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 8 Feb 2023 07:27:55 +0100
Subject: Fixed Lora training

---
 train_dreambooth.py       | 12 ++++++------
 train_lora.py             | 25 ++++++++++++++++++-------
 train_ti.py               | 12 ++++++------
 training/strategy/lora.py | 23 +++++------------------
 4 files changed, 35 insertions(+), 37 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a4c47b..a29c507 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -442,6 +442,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -495,12 +501,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     trainer = partial(
         train,
         accelerator=accelerator,
diff --git a/train_lora.py b/train_lora.py
index b273ae1..ab1753b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -177,6 +177,11 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
     parser.add_argument(
         "--find_lr",
         action="store_true",
@@ -402,6 +407,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -418,6 +429,12 @@ def main():
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     lora_attn_procs = {}
     for name in unet.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
@@ -467,12 +484,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     trainer = partial(
         train,
         accelerator=accelerator,
diff --git a/train_ti.py b/train_ti.py
index 56f9e97..2840def 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -513,6 +513,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -564,12 +570,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     checkpoint_output_dir = output_dir.joinpath("checkpoints")
 
     trainer = partial(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 92abaa6..bc10e58 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -89,20 +89,14 @@ def lora_strategy_callbacks(
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=torch.float32)
-        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
-        unet.to(dtype=orig_unet_dtype)
+
+        unet_ = accelerator.unwrap_model(unet)
+        unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=weight_dtype)
         save_samples_(step=step)
-        unet.to(dtype=orig_unet_dtype)
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
@@ -126,16 +120,9 @@ def lora_prepare(
     lora_layers: AttnProcsLayers,
     **kwargs
 ):
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
-- 
cgit v1.2.3-70-g09d2
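
Note (not part of the patch): the checkpoints this commit writes in on_checkpoint() via unet_.save_attn_procs(...) can be loaded back for inference with the matching diffusers loader API. Below is a minimal sketch assuming a diffusers release from around the time of this commit (where UNet2DConditionModel exposes load_attn_procs); the base model id, checkpoint path, and prompt are placeholders, not values taken from this repository.

    # Sketch: load LoRA attention processors saved by save_attn_procs() into a
    # Stable Diffusion pipeline for inference. Paths/model id are placeholders.
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder base model
        torch_dtype=torch.float16,
    )
    # Placeholder path: a "{step}_{postfix}" directory under checkpoint_output_dir
    pipe.unet.load_attn_procs("output/checkpoints/1000_end")
    pipe = pipe.to("cuda")

    image = pipe("a photo of <instance>", num_inference_steps=30).images[0]
    image.save("sample.png")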