summaryrefslogtreecommitdiffstats
path: root/training/strategy/lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r--training/strategy/lora.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1f0a117..3f4dbbc 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -138,6 +138,14 @@ def lora_strategy_callbacks(
138 state_dict.update(text_encoder_state_dict) 138 state_dict.update(text_encoder_state_dict)
139 lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) 139 lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
140 140
141 if len(placeholder_tokens) != 0:
142 ti_state_dict = {
143 f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
144 for (token, ids)
145 in zip(placeholder_tokens, placeholder_token_ids)
146 }
147 state_dict.update(ti_state_dict)
148
141 save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") 149 save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
142 with open(checkpoint_output_dir / "lora_config.json", "w") as f: 150 with open(checkpoint_output_dir / "lora_config.json", "w") as f:
143 json.dump(lora_config, f) 151 json.dump(lora_config, f)