summaryrefslogtreecommitdiffstats
path: root/training/strategy
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-07 20:56:37 +0100
committerVolpeon <git@volpeon.ink>2023-02-07 20:56:37 +0100
commit757e6af0af2f8de6da696976f3110cd70085adad (patch)
treecf219fffa359440bc8f2a2d6dd4a647715d66893 /training/strategy
parentAdd Lora (diff)
downloadtextual-inversion-diff-757e6af0af2f8de6da696976f3110cd70085adad.tar.gz
textual-inversion-diff-757e6af0af2f8de6da696976f3110cd70085adad.tar.bz2
textual-inversion-diff-757e6af0af2f8de6da696976f3110cd70085adad.zip
Fix Lora memory usage
Diffstat (limited to 'training/strategy')
-rw-r--r--training/strategy/dreambooth.py2
-rw-r--r--training/strategy/lora.py4
-rw-r--r--training/strategy/ti.py2
3 files changed, 1 insertions, 7 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index b4c77f3..8aaed3a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -1,4 +1,3 @@
1from contextlib import nullcontext
2from typing import Optional 1from typing import Optional
3from functools import partial 2from functools import partial
4from contextlib import contextmanager, nullcontext 3from contextlib import contextmanager, nullcontext
@@ -6,7 +5,6 @@ from pathlib import Path
6import itertools 5import itertools
7 6
8import torch 7import torch
9import torch.nn as nn
10from torch.utils.data import DataLoader 8from torch.utils.data import DataLoader
11 9
12from accelerate import Accelerator 10from accelerate import Accelerator
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 88d1824..92abaa6 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -1,11 +1,9 @@
1from contextlib import nullcontext
2from typing import Optional 1from typing import Optional
3from functools import partial 2from functools import partial
4from contextlib import contextmanager, nullcontext 3from contextlib import contextmanager
5from pathlib import Path 4from pathlib import Path
6 5
7import torch 6import torch
8import torch.nn as nn
9from torch.utils.data import DataLoader 7from torch.utils.data import DataLoader
10 8
11from accelerate import Accelerator 9from accelerate import Accelerator
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index d306f18..da2b81c 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,11 +1,9 @@
1from contextlib import nullcontext
2from typing import Optional 1from typing import Optional
3from functools import partial 2from functools import partial
4from contextlib import contextmanager, nullcontext 3from contextlib import contextmanager, nullcontext
5from pathlib import Path 4from pathlib import Path
6 5
7import torch 6import torch
8import torch.nn as nn
9from torch.utils.data import DataLoader 7from torch.utils.data import DataLoader
10 8
11from accelerate import Accelerator 9from accelerate import Accelerator