diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-24 10:58:10 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-24 10:58:10 +0100 |
| commit | f0471f21f419a34e3fd7b7b03a4292c139fda674 (patch) | |
| tree | d7cf72401c49cd7afa4ce16c4e550810ac39f454 /training | |
| parent | Refactoring, fixed Lora training (diff) | |
| download | textual-inversion-diff-f0471f21f419a34e3fd7b7b03a4292c139fda674.tar.gz textual-inversion-diff-f0471f21f419a34e3fd7b7b03a4292c139fda674.tar.bz2 textual-inversion-diff-f0471f21f419a34e3fd7b7b03a4292c139fda674.zip | |
Lora fix: Save config JSON, too
Diffstat (limited to 'training')
| -rw-r--r-- | training/strategy/lora.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 3971eae..1e32114 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -3,6 +3,7 @@ from functools import partial | |||
| 3 | from contextlib import contextmanager | 3 | from contextlib import contextmanager |
| 4 | from pathlib import Path | 4 | from pathlib import Path |
| 5 | import itertools | 5 | import itertools |
| 6 | import json | ||
| 6 | 7 | ||
| 7 | import torch | 8 | import torch |
| 8 | from torch.utils.data import DataLoader | 9 | from torch.utils.data import DataLoader |
| @@ -95,6 +96,8 @@ def lora_strategy_callbacks( | |||
| 95 | 96 | ||
| 96 | accelerator.print(state_dict) | 97 | accelerator.print(state_dict) |
| 97 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") | 98 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") |
| 99 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | ||
| 100 | json.dump(lora_config, f) | ||
| 98 | 101 | ||
| 99 | del unet_ | 102 | del unet_ |
| 100 | del text_encoder_ | 103 | del text_encoder_ |
