summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-20 22:07:06 +0100
committerVolpeon <git@volpeon.ink>2022-12-20 22:07:06 +0100
commit1d5038280d44a36351cb3aa21aad7a8eff220c94 (patch)
tree55ae4c0a660f5218c072d33f2896024b47c05c6b /train_ti.py
parentDependency cleanup/upgrades (diff)
downloadtextual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.gz
textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.bz2
textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.zip
Fix training
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py
index 5c0299e..9616db6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -24,7 +24,6 @@ from slugify import slugify
24 24
25from common import load_text_embeddings, load_text_embedding 25from common import load_text_embeddings, load_text_embedding
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from pipelines.util import set_use_memory_efficient_attention_xformers
28from data.csv import CSVDataModule, CSVDataItem 27from data.csv import CSVDataModule, CSVDataItem
29from training.optimization import get_one_cycle_schedule 28from training.optimization import get_one_cycle_schedule
30from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
@@ -557,8 +556,8 @@ def main():
557 args.pretrained_model_name_or_path, subfolder='scheduler') 556 args.pretrained_model_name_or_path, subfolder='scheduler')
558 557
559 vae.enable_slicing() 558 vae.enable_slicing()
560 set_use_memory_efficient_attention_xformers(unet, True) 559 vae.set_use_memory_efficient_attention_xformers(True)
561 set_use_memory_efficient_attention_xformers(vae, True) 560 unet.set_use_memory_efficient_attention_xformers(True)
562 561
563 if args.gradient_checkpointing: 562 if args.gradient_checkpointing:
564 unet.enable_gradient_checkpointing() 563 unet.enable_gradient_checkpointing()