From 757e6af0af2f8de6da696976f3110cd70085adad Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 7 Feb 2023 20:56:37 +0100 Subject: Fix Lora memory usage --- train_lora.py | 4 ++-- training/functional.py | 4 ++-- training/strategy/dreambooth.py | 2 -- training/strategy/lora.py | 4 +--- training/strategy/ti.py | 2 -- 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/train_lora.py b/train_lora.py index 2cb85cc..b273ae1 100644 --- a/train_lora.py +++ b/train_lora.py @@ -13,7 +13,7 @@ from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from slugify import slugify from diffusers.loaders import AttnProcsLayers -from diffusers.models.cross_attention import LoRACrossAttnProcessor +from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor from util import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter @@ -430,7 +430,7 @@ def main(): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( + lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim ) diff --git a/training/functional.py b/training/functional.py index 8f47734..ccbb4ad 100644 --- a/training/functional.py +++ b/training/functional.py @@ -261,8 +261,8 @@ def loss_step( eval: bool = False ): # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() - latents = latents * 0.18215 + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * vae.config.scaling_factor generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index b4c77f3..8aaed3a 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -1,4 +1,3 @@ -from contextlib import nullcontext from typing import Optional from functools import partial from contextlib import contextmanager, nullcontext @@ -6,7 +5,6 @@ from pathlib import Path import itertools import torch -import torch.nn as nn from torch.utils.data import DataLoader from accelerate import Accelerator diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 88d1824..92abaa6 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -1,11 +1,9 @@ -from contextlib import nullcontext from typing import Optional from functools import partial -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager from pathlib import Path import torch -import torch.nn as nn from torch.utils.data import DataLoader from accelerate import Accelerator diff --git a/training/strategy/ti.py b/training/strategy/ti.py index d306f18..da2b81c 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -1,11 +1,9 @@ -from contextlib import nullcontext from typing import Optional from functools import partial from contextlib import contextmanager, nullcontext from pathlib import Path import torch -import torch.nn as nn from torch.utils.data import DataLoader from accelerate import Accelerator -- cgit v1.2.3-70-g09d2