diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/lora.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1f0a117..3f4dbbc 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -138,6 +138,14 @@ def lora_strategy_callbacks( | |||
138 | state_dict.update(text_encoder_state_dict) | 138 | state_dict.update(text_encoder_state_dict) |
139 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) | 139 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) |
140 | 140 | ||
141 | if len(placeholder_tokens) != 0: | ||
142 | ti_state_dict = { | ||
143 | f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) | ||
144 | for (token, ids) | ||
145 | in zip(placeholder_tokens, placeholder_token_ids) | ||
146 | } | ||
147 | state_dict.update(ti_state_dict) | ||
148 | |||
141 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") | 149 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") |
142 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | 150 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: |
143 | json.dump(lora_config, f) | 151 | json.dump(lora_config, f) |