diff options
author | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
commit | 6d46bf79bd7710cea799fbfe27c12d06d12cd53f (patch) | |
tree | 6c65817b9351453bfb5366f7010f8d87659c0dd0 /training/strategy | |
parent | Fix cycle loop (diff) | |
download | textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.gz textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.bz2 textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.zip |
Update
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/lora.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1f0a117..3f4dbbc 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -138,6 +138,14 @@ def lora_strategy_callbacks( | |||
138 | state_dict.update(text_encoder_state_dict) | 138 | state_dict.update(text_encoder_state_dict) |
139 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) | 139 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) |
140 | 140 | ||
141 | if len(placeholder_tokens) != 0: | ||
142 | ti_state_dict = { | ||
143 | f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) | ||
144 | for (token, ids) | ||
145 | in zip(placeholder_tokens, placeholder_token_ids) | ||
146 | } | ||
147 | state_dict.update(ti_state_dict) | ||
148 | |||
141 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") | 149 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") |
142 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | 150 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: |
143 | json.dump(lora_config, f) | 151 | json.dump(lora_config, f) |