From 6d46bf79bd7710cea799fbfe27c12d06d12cd53f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 27 Apr 2023 07:47:59 +0200 Subject: Update --- training/strategy/lora.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'training/strategy') diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1f0a117..3f4dbbc 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -138,6 +138,14 @@ def lora_strategy_callbacks( state_dict.update(text_encoder_state_dict) lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) + if len(placeholder_tokens) != 0: + ti_state_dict = { + f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) + for (token, ids) + in zip(placeholder_tokens, placeholder_token_ids) + } + state_dict.update(ti_state_dict) + save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") with open(checkpoint_output_dir / "lora_config.json", "w") as f: json.dump(lora_config, f) -- cgit v1.2.3-70-g09d2