diff options
author | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
commit | 1d5038280d44a36351cb3aa21aad7a8eff220c94 (patch) | |
tree | 55ae4c0a660f5218c072d33f2896024b47c05c6b /train_ti.py | |
parent | Dependency cleanup/upgrades (diff) | |
download | textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.gz textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.tar.bz2 textual-inversion-diff-1d5038280d44a36351cb3aa21aad7a8eff220c94.zip |
Fix training
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 5c0299e..9616db6 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -24,7 +24,6 @@ from slugify import slugify | |||
24 | 24 | ||
25 | from common import load_text_embeddings, load_text_embedding | 25 | from common import load_text_embeddings, load_text_embedding |
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
28 | from data.csv import CSVDataModule, CSVDataItem | 27 | from data.csv import CSVDataModule, CSVDataItem |
29 | from training.optimization import get_one_cycle_schedule | 28 | from training.optimization import get_one_cycle_schedule |
30 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
@@ -557,8 +556,8 @@ def main(): | |||
557 | args.pretrained_model_name_or_path, subfolder='scheduler') | 556 | args.pretrained_model_name_or_path, subfolder='scheduler') |
558 | 557 | ||
559 | vae.enable_slicing() | 558 | vae.enable_slicing() |
560 | set_use_memory_efficient_attention_xformers(unet, True) | 559 | vae.set_use_memory_efficient_attention_xformers(True) |
561 | set_use_memory_efficient_attention_xformers(vae, True) | 560 | unet.set_use_memory_efficient_attention_xformers(True) |
562 | 561 | ||
563 | if args.gradient_checkpointing: | 562 | if args.gradient_checkpointing: |
564 | unet.enable_gradient_checkpointing() | 563 | unet.enable_gradient_checkpointing() |