Diffstat (limited to 'training/strategy/lora.py')
| -rw-r--r-- | training/strategy/lora.py | 23 |
1 file changed, 5 insertions, 18 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 92abaa6..bc10e58 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -89,20 +89,14 @@ def lora_strategy_callbacks(
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=torch.float32)
-        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
-        unet.to(dtype=orig_unet_dtype)
+
+        unet_ = accelerator.unwrap_model(unet)
+        unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=weight_dtype)
         save_samples_(step=step)
-        unet.to(dtype=orig_unet_dtype)
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
@@ -126,16 +120,9 @@ def lora_prepare(
     lora_layers: AttnProcsLayers,
     **kwargs
 ):
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
 
 
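For context, the per-step directories written by `save_attn_procs` in the new `on_checkpoint` can later be attached to a diffusers pipeline with `load_attn_procs`. The following is a minimal sketch, not part of this commit: the base model id, the `checkpoints/1000_milestone` path, and the prompt are placeholders standing in for whatever base model this LoRA was trained against and for one of the `{step}_{postfix}` directories produced above.

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder base model and checkpoint directory; substitute the model the
# LoRA layers were trained against and a "{step}_{postfix}" directory written
# by on_checkpoint above.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# load_attn_procs attaches the saved LoRA attention processors to the UNet.
pipe.unet.load_attn_procs("checkpoints/1000_milestone")

image = pipe("a photo in the trained style", num_inference_steps=30).images[0]
image.save("lora_sample.png")
```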
