From 9ea20241bbeb2f32199067096272e13647c512eb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 8 Feb 2023 07:27:55 +0100 Subject: Fixed Lora training --- training/strategy/lora.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) (limited to 'training/strategy/lora.py') diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 92abaa6..bc10e58 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -89,20 +89,14 @@ def lora_strategy_callbacks( @torch.no_grad() def on_checkpoint(step, postfix): print(f"Saving checkpoint for step {step}...") - orig_unet_dtype = unet.dtype - unet.to(dtype=torch.float32) - unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}")) - unet.to(dtype=orig_unet_dtype) + + unet_ = accelerator.unwrap_model(unet) + unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}")) + del unet_ @torch.no_grad() def on_sample(step): - orig_unet_dtype = unet.dtype - unet.to(dtype=weight_dtype) save_samples_(step=step) - unet.to(dtype=orig_unet_dtype) - - if torch.cuda.is_available(): - torch.cuda.empty_cache() return TrainingCallbacks( on_prepare=on_prepare, @@ -126,16 +120,9 @@ def lora_prepare( lora_layers: AttnProcsLayers, **kwargs ): - weight_dtype = torch.float32 - if accelerator.state.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.state.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler) - unet.to(accelerator.device, dtype=weight_dtype) - text_encoder.to(accelerator.device, dtype=weight_dtype) + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers} -- cgit v1.2.3-54-g00ecf