author    Volpeon <git@volpeon.ink>  2023-02-07 20:56:37 +0100
committer Volpeon <git@volpeon.ink>  2023-02-07 20:56:37 +0100
commit    757e6af0af2f8de6da696976f3110cd70085adad (patch)
tree      cf219fffa359440bc8f2a2d6dd4a647715d66893
parent    Add Lora (diff)
Fix Lora memory usage
-rw-r--r--  train_lora.py                    4
-rw-r--r--  training/functional.py           4
-rw-r--r--  training/strategy/dreambooth.py  2
-rw-r--r--  training/strategy/lora.py        4
-rw-r--r--  training/strategy/ti.py          2
5 files changed, 5 insertions, 11 deletions
diff --git a/train_lora.py b/train_lora.py
index 2cb85cc..b273ae1 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -430,7 +430,7 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
+        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
         )
 
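For context, a minimal sketch of how these xFormers-backed LoRA attention processors end up attached to the UNet, following the standard diffusers LoRA pattern of this era. The model name and surrounding setup are assumptions for illustration, not part of this commit:

```python
from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor

# Model name is an assumption for illustration; xformers must be installed,
# since LoRAXFormersCrossAttnProcessor runs attention through xformers kernels.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers (attn1) have no cross-attention conditioning.
    cross_attention_dim = (
        None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    )
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    # The xFormers variant computes attention with memory-efficient kernels,
    # which is what lowers peak memory during LoRA training.
    lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)  # the trainable LoRA parameters
```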
diff --git a/training/functional.py b/training/functional.py
index 8f47734..ccbb4ad 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -261,8 +261,8 @@ def loss_step(
     eval: bool = False
 ):
     # Convert images to latent space
-    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-    latents = latents * 0.18215
+    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+    latents = latents * vae.config.scaling_factor
 
     generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
 
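A minimal sketch of the latent-encoding change above: the latents are scaled by the VAE's configured scaling factor instead of the hard-coded 0.18215. The model name, the dummy batch, and the torch.no_grad wrapper are assumptions for illustration, not part of the diff:

```python
import torch
from diffusers import AutoencoderKL

# Model name is an assumption for illustration.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.requires_grad_(False)

pixel_values = torch.randn(1, 3, 512, 512)  # placeholder image batch in [-1, 1]

with torch.no_grad():
    # Encode images into the VAE latent space and apply the configured
    # scaling factor (0.18215 for SD 1.x VAEs) rather than hard-coding it.
    latents = vae.encode(pixel_values).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
```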
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index b4c77f3..8aaed3a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -1,4 +1,3 @@
-from contextlib import nullcontext
 from typing import Optional
 from functools import partial
 from contextlib import contextmanager, nullcontext
@@ -6,7 +5,6 @@ from pathlib import Path
 import itertools
 
 import torch
-import torch.nn as nn
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 88d1824..92abaa6 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -1,11 +1,9 @@
-from contextlib import nullcontext
 from typing import Optional
 from functools import partial
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from pathlib import Path
 
 import torch
-import torch.nn as nn
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index d306f18..da2b81c 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,11 +1,9 @@
-from contextlib import nullcontext
 from typing import Optional
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
 
 import torch
-import torch.nn as nn
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator