From 1d5038280d44a36351cb3aa21aad7a8eff220c94 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 20 Dec 2022 22:07:06 +0100 Subject: Fix training --- train_ti.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 5c0299e..9616db6 100644 --- a/train_ti.py +++ b/train_ti.py @@ -24,7 +24,6 @@ from slugify import slugify from common import load_text_embeddings, load_text_embedding from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from pipelines.util import set_use_memory_efficient_attention_xformers from data.csv import CSVDataModule, CSVDataItem from training.optimization import get_one_cycle_schedule from models.clip.prompt import PromptProcessor @@ -557,8 +556,8 @@ def main(): args.pretrained_model_name_or_path, subfolder='scheduler') vae.enable_slicing() - set_use_memory_efficient_attention_xformers(unet, True) - set_use_memory_efficient_attention_xformers(vae, True) + vae.set_use_memory_efficient_attention_xformers(True) + unet.set_use_memory_efficient_attention_xformers(True) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() -- cgit v1.2.3-54-g00ecf