From 9ea20241bbeb2f32199067096272e13647c512eb Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Wed, 8 Feb 2023 07:27:55 +0100
Subject: Fix LoRA training

---
 training/strategy/lora.py | 23 +++++------------------
 1 file changed, 5 insertions(+), 18 deletions(-)

diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 92abaa6..bc10e58 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -89,20 +89,14 @@ def lora_strategy_callbacks(
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=torch.float32)
-        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
-        unet.to(dtype=orig_unet_dtype)
+
+        unet_ = accelerator.unwrap_model(unet)
+        unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=weight_dtype)
         save_samples_(step=step)
-        unet.to(dtype=orig_unet_dtype)
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
@@ -126,16 +120,9 @@ def lora_prepare(
     lora_layers: AttnProcsLayers,
     **kwargs
 ):
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
 
 
-- 
cgit v1.2.3-70-g09d2
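
For reference, a minimal sketch of the checkpointing pattern this patch adopts (not part of the commit itself): it assumes a diffusers UNet with LoRA attention processors already attached and an accelerate Accelerator, with checkpoint_output_dir, step and postfix named as in the patch. Instead of round-tripping the UNet through float32, the module handed back by accelerator.prepare() is unwrapped and only the attention processors are saved; device placement and mixed precision are left to accelerate.

    from pathlib import Path

    import torch
    from accelerate import Accelerator


    @torch.no_grad()
    def save_lora_checkpoint(accelerator: Accelerator, unet,
                             checkpoint_output_dir: Path, step: int, postfix: str):
        # Unwrap the (possibly DDP/AMP-wrapped) module returned by accelerator.prepare(),
        # then save only the LoRA attention processors -- no manual dtype casts needed.
        unet_ = accelerator.unwrap_model(unet)
        unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
        del unet_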