summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-07 20:56:37 +0100
committerVolpeon <git@volpeon.ink>2023-02-07 20:56:37 +0100
commit757e6af0af2f8de6da696976f3110cd70085adad (patch)
treecf219fffa359440bc8f2a2d6dd4a647715d66893 /training
parentAdd Lora (diff)
downloadtextual-inversion-diff-757e6af0af2f8de6da696976f3110cd70085adad.tar.gz
textual-inversion-diff-757e6af0af2f8de6da696976f3110cd70085adad.tar.bz2
textual-inversion-diff-757e6af0af2f8de6da696976f3110cd70085adad.zip
Fix Lora memory usage
Diffstat (limited to 'training')
-rw-r--r--training/functional.py4
-rw-r--r--training/strategy/dreambooth.py2
-rw-r--r--training/strategy/lora.py4
-rw-r--r--training/strategy/ti.py2
4 files changed, 3 insertions, 9 deletions
diff --git a/training/functional.py b/training/functional.py
index 8f47734..ccbb4ad 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -261,8 +261,8 @@ def loss_step(
261 eval: bool = False 261 eval: bool = False
262): 262):
263 # Convert images to latent space 263 # Convert images to latent space
264 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() 264 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
265 latents = latents * 0.18215 265 latents = latents * vae.config.scaling_factor
266 266
267 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None 267 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
268 268
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index b4c77f3..8aaed3a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -1,4 +1,3 @@
1from contextlib import nullcontext
2from typing import Optional 1from typing import Optional
3from functools import partial 2from functools import partial
4from contextlib import contextmanager, nullcontext 3from contextlib import contextmanager, nullcontext
@@ -6,7 +5,6 @@ from pathlib import Path
6import itertools 5import itertools
7 6
8import torch 7import torch
9import torch.nn as nn
10from torch.utils.data import DataLoader 8from torch.utils.data import DataLoader
11 9
12from accelerate import Accelerator 10from accelerate import Accelerator
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 88d1824..92abaa6 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -1,11 +1,9 @@
1from contextlib import nullcontext
2from typing import Optional 1from typing import Optional
3from functools import partial 2from functools import partial
4from contextlib import contextmanager, nullcontext 3from contextlib import contextmanager
5from pathlib import Path 4from pathlib import Path
6 5
7import torch 6import torch
8import torch.nn as nn
9from torch.utils.data import DataLoader 7from torch.utils.data import DataLoader
10 8
11from accelerate import Accelerator 9from accelerate import Accelerator
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index d306f18..da2b81c 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,11 +1,9 @@
1from contextlib import nullcontext
2from typing import Optional 1from typing import Optional
3from functools import partial 2from functools import partial
4from contextlib import contextmanager, nullcontext 3from contextlib import contextmanager, nullcontext
5from pathlib import Path 4from pathlib import Path
6 5
7import torch 6import torch
8import torch.nn as nn
9from torch.utils.data import DataLoader 7from torch.utils.data import DataLoader
10 8
11from accelerate import Accelerator 9from accelerate import Accelerator