path: root/training/strategy/lora.py
author    Volpeon <git@volpeon.ink>  2023-02-08 07:27:55 +0100
committer Volpeon <git@volpeon.ink>  2023-02-08 07:27:55 +0100
commit    9ea20241bbeb2f32199067096272e13647c512eb (patch)
tree      9e0891a74d0965da75e9d3f30628b69d5ba3deaf /training/strategy/lora.py
parent    Fix Lora memory usage (diff)
Fixed Lora training
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r--  training/strategy/lora.py  23
1 file changed, 5 insertions(+), 18 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 92abaa6..bc10e58 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -89,20 +89,14 @@ def lora_strategy_callbacks(
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=torch.float32)
-        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
-        unet.to(dtype=orig_unet_dtype)
+
+        unet_ = accelerator.unwrap_model(unet)
+        unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=weight_dtype)
         save_samples_(step=step)
-        unet.to(dtype=orig_unet_dtype)
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
@@ -126,16 +120,9 @@ def lora_prepare(
     lora_layers: AttnProcsLayers,
     **kwargs
 ):
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
 
 
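
The first hunk stops casting the UNet to float32 and back around save_attn_procs(); on_checkpoint now unwraps the accelerator-wrapped model and saves the LoRA attention processors directly. For context, the sketch below shows one way the processors written by save_attn_procs() could be loaded back for inference with diffusers' load_attn_procs(). The base model id, checkpoint path and prompt are placeholders, not values from this repository.

import torch
from diffusers import StableDiffusionPipeline

# Placeholder base model; the checkpoint directory is whatever
# checkpoint_output_dir.joinpath(f"{step}_{postfix}") resolved to.
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# load_attn_procs() restores the LoRA cross-attention processors that
# unet_.save_attn_procs(...) wrote in the on_checkpoint callback above.
pipeline.unet.load_attn_procs("checkpoints/1000_end")
pipeline.to("cuda")

image = pipeline("a photo in the trained style").images[0]
image.save("sample.png")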
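The second hunk removes the manual weight_dtype selection and the unet/text_encoder device moves, so lora_prepare() now only passes the trainable AttnProcsLayers (plus optimizer, dataloaders and scheduler) through accelerator.prepare(). As a reminder of where that lora_layers argument typically comes from, the sketch below follows the standard diffusers LoRA setup of this era; import paths and class names vary between diffusers versions, so treat it as an illustration rather than this repository's exact setup code.

from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# Placeholder base model.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Attach a LoRA processor to every attention module of the UNet.
# Self-attention layers (attn1) have no cross-attention dimension.
lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(lora_attn_procs)

# Only these layers carry gradients; this object is what lora_prepare()
# receives as lora_layers and passes through accelerator.prepare().
lora_layers = AttnProcsLayers(unet.attn_processors)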