diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/lora.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1e32114..8905171 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -12,6 +12,7 @@ from accelerate import Accelerator | |||
12 | from transformers import CLIPTextModel | 12 | from transformers import CLIPTextModel |
13 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler | 13 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler |
14 | from peft import get_peft_model_state_dict | 14 | from peft import get_peft_model_state_dict |
15 | from safetensors.torch import save_file | ||
15 | 16 | ||
16 | from models.clip.tokenizer import MultiCLIPTokenizer | 17 | from models.clip.tokenizer import MultiCLIPTokenizer |
17 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | 18 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
@@ -78,6 +79,9 @@ def lora_strategy_callbacks( | |||
78 | 79 | ||
79 | @torch.no_grad() | 80 | @torch.no_grad() |
80 | def on_checkpoint(step, postfix): | 81 | def on_checkpoint(step, postfix): |
82 | if postfix != "end": | ||
83 | return | ||
84 | |||
81 | print(f"Saving checkpoint for step {step}...") | 85 | print(f"Saving checkpoint for step {step}...") |
82 | 86 | ||
83 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 87 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
@@ -94,26 +98,23 @@ def lora_strategy_callbacks( | |||
94 | state_dict.update(text_encoder_state_dict) | 98 | state_dict.update(text_encoder_state_dict) |
95 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) | 99 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) |
96 | 100 | ||
97 | accelerator.print(state_dict) | 101 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") |
98 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") | ||
99 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | 102 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: |
100 | json.dump(lora_config, f) | 103 | json.dump(lora_config, f) |
101 | 104 | ||
102 | del unet_ | 105 | del unet_ |
103 | del text_encoder_ | 106 | del text_encoder_ |
104 | 107 | ||
108 | if torch.cuda.is_available(): | ||
109 | torch.cuda.empty_cache() | ||
110 | |||
105 | @torch.no_grad() | 111 | @torch.no_grad() |
106 | def on_sample(step): | 112 | def on_sample(step): |
107 | vae_dtype = vae.dtype | ||
108 | vae.to(dtype=text_encoder.dtype) | ||
109 | |||
110 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 113 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
111 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 114 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
112 | 115 | ||
113 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 116 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) |
114 | 117 | ||
115 | vae.to(dtype=vae_dtype) | ||
116 | |||
117 | del unet_ | 118 | del unet_ |
118 | del text_encoder_ | 119 | del text_encoder_ |
119 | 120 | ||