diff options
-rw-r--r-- | training/strategy/lora.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 3971eae..1e32114 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -3,6 +3,7 @@ from functools import partial | |||
3 | from contextlib import contextmanager | 3 | from contextlib import contextmanager |
4 | from pathlib import Path | 4 | from pathlib import Path |
5 | import itertools | 5 | import itertools |
6 | import json | ||
6 | 7 | ||
7 | import torch | 8 | import torch |
8 | from torch.utils.data import DataLoader | 9 | from torch.utils.data import DataLoader |
@@ -95,6 +96,8 @@ def lora_strategy_callbacks( | |||
95 | 96 | ||
96 | accelerator.print(state_dict) | 97 | accelerator.print(state_dict) |
97 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") | 98 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") |
99 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | ||
100 | json.dump(lora_config, f) | ||
98 | 101 | ||
99 | del unet_ | 102 | del unet_ |
100 | del text_encoder_ | 103 | del text_encoder_ |