summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-24 10:58:10 +0100
committerVolpeon <git@volpeon.ink>2023-03-24 10:58:10 +0100
commitf0471f21f419a34e3fd7b7b03a4292c139fda674 (patch)
treed7cf72401c49cd7afa4ce16c4e550810ac39f454 /training
parentRefactoring, fixed Lora training (diff)
downloadtextual-inversion-diff-f0471f21f419a34e3fd7b7b03a4292c139fda674.tar.gz
textual-inversion-diff-f0471f21f419a34e3fd7b7b03a4292c139fda674.tar.bz2
textual-inversion-diff-f0471f21f419a34e3fd7b7b03a4292c139fda674.zip
Lora fix: Save config JSON, too
Diffstat (limited to 'training')
-rw-r--r--training/strategy/lora.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 3971eae..1e32114 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -3,6 +3,7 @@ from functools import partial
3from contextlib import contextmanager 3from contextlib import contextmanager
4from pathlib import Path 4from pathlib import Path
5import itertools 5import itertools
6import json
6 7
7import torch 8import torch
8from torch.utils.data import DataLoader 9from torch.utils.data import DataLoader
@@ -95,6 +96,8 @@ def lora_strategy_callbacks(
95 96
96 accelerator.print(state_dict) 97 accelerator.print(state_dict)
97 accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") 98 accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
99 with open(checkpoint_output_dir / "lora_config.json", "w") as f:
100 json.dump(lora_config, f)
98 101
99 del unet_ 102 del unet_
100 del text_encoder_ 103 del text_encoder_