diff options
Diffstat (limited to 'training/strategy')
| -rw-r--r-- | training/strategy/lora.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 3971eae..1e32114 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -3,6 +3,7 @@ from functools import partial | |||
| 3 | from contextlib import contextmanager | 3 | from contextlib import contextmanager |
| 4 | from pathlib import Path | 4 | from pathlib import Path |
| 5 | import itertools | 5 | import itertools |
| 6 | import json | ||
| 6 | 7 | ||
| 7 | import torch | 8 | import torch |
| 8 | from torch.utils.data import DataLoader | 9 | from torch.utils.data import DataLoader |
| @@ -95,6 +96,8 @@ def lora_strategy_callbacks( | |||
| 95 | 96 | ||
| 96 | accelerator.print(state_dict) | 97 | accelerator.print(state_dict) |
| 97 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") | 98 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") |
| 99 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | ||
| 100 | json.dump(lora_config, f) | ||
| 98 | 101 | ||
| 99 | del unet_ | 102 | del unet_ |
| 100 | del text_encoder_ | 103 | del text_encoder_ |
