diff options
author | Volpeon <git@volpeon.ink> | 2023-03-24 10:58:10 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-24 10:58:10 +0100 |
commit | f0471f21f419a34e3fd7b7b03a4292c139fda674 (patch) | |
tree | d7cf72401c49cd7afa4ce16c4e550810ac39f454 | |
parent | Refactoring, fixed Lora training (diff) | |
download | textual-inversion-diff-f0471f21f419a34e3fd7b7b03a4292c139fda674.tar.gz textual-inversion-diff-f0471f21f419a34e3fd7b7b03a4292c139fda674.tar.bz2 textual-inversion-diff-f0471f21f419a34e3fd7b7b03a4292c139fda674.zip |
Lora fix: Save config JSON, too
-rw-r--r-- | training/strategy/lora.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 3971eae..1e32114 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -3,6 +3,7 @@ from functools import partial | |||
3 | from contextlib import contextmanager | 3 | from contextlib import contextmanager |
4 | from pathlib import Path | 4 | from pathlib import Path |
5 | import itertools | 5 | import itertools |
6 | import json | ||
6 | 7 | ||
7 | import torch | 8 | import torch |
8 | from torch.utils.data import DataLoader | 9 | from torch.utils.data import DataLoader |
@@ -95,6 +96,8 @@ def lora_strategy_callbacks( | |||
95 | 96 | ||
96 | accelerator.print(state_dict) | 97 | accelerator.print(state_dict) |
97 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") | 98 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") |
99 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | ||
100 | json.dump(lora_config, f) | ||
98 | 101 | ||
99 | del unet_ | 102 | del unet_ |
100 | del text_encoder_ | 103 | del text_encoder_ |