diff options
author | Volpeon <git@volpeon.ink> | 2023-02-07 20:56:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-07 20:56:37 +0100 |
commit | 757e6af0af2f8de6da696976f3110cd70085adad (patch) | |
tree | cf219fffa359440bc8f2a2d6dd4a647715d66893 /training/strategy | |
parent | Add Lora (diff) | |
download | textual-inversion-diff-757e6af0af2f8de6da696976f3110cd70085adad.tar.gz textual-inversion-diff-757e6af0af2f8de6da696976f3110cd70085adad.tar.bz2 textual-inversion-diff-757e6af0af2f8de6da696976f3110cd70085adad.zip |
Fix Lora memory usage
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
-rw-r--r-- | training/strategy/lora.py | 4 | ||||
-rw-r--r-- | training/strategy/ti.py | 2 |
3 files changed, 1 insertions, 7 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index b4c77f3..8aaed3a 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -1,4 +1,3 @@ | |||
1 | from contextlib import nullcontext | ||
2 | from typing import Optional | 1 | from typing import Optional |
3 | from functools import partial | 2 | from functools import partial |
4 | from contextlib import contextmanager, nullcontext | 3 | from contextlib import contextmanager, nullcontext |
@@ -6,7 +5,6 @@ from pathlib import Path | |||
6 | import itertools | 5 | import itertools |
7 | 6 | ||
8 | import torch | 7 | import torch |
9 | import torch.nn as nn | ||
10 | from torch.utils.data import DataLoader | 8 | from torch.utils.data import DataLoader |
11 | 9 | ||
12 | from accelerate import Accelerator | 10 | from accelerate import Accelerator |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 88d1824..92abaa6 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -1,11 +1,9 @@ | |||
1 | from contextlib import nullcontext | ||
2 | from typing import Optional | 1 | from typing import Optional |
3 | from functools import partial | 2 | from functools import partial |
4 | from contextlib import contextmanager, nullcontext | 3 | from contextlib import contextmanager |
5 | from pathlib import Path | 4 | from pathlib import Path |
6 | 5 | ||
7 | import torch | 6 | import torch |
8 | import torch.nn as nn | ||
9 | from torch.utils.data import DataLoader | 7 | from torch.utils.data import DataLoader |
10 | 8 | ||
11 | from accelerate import Accelerator | 9 | from accelerate import Accelerator |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index d306f18..da2b81c 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -1,11 +1,9 @@ | |||
1 | from contextlib import nullcontext | ||
2 | from typing import Optional | 1 | from typing import Optional |
3 | from functools import partial | 2 | from functools import partial |
4 | from contextlib import contextmanager, nullcontext | 3 | from contextlib import contextmanager, nullcontext |
5 | from pathlib import Path | 4 | from pathlib import Path |
6 | 5 | ||
7 | import torch | 6 | import torch |
8 | import torch.nn as nn | ||
9 | from torch.utils.data import DataLoader | 7 | from torch.utils.data import DataLoader |
10 | 8 | ||
11 | from accelerate import Accelerator | 9 | from accelerate import Accelerator |