| author | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-20 22:07:06 +0100 |
| commit | 1d5038280d44a36351cb3aa21aad7a8eff220c94 (patch) | |
| tree | 55ae4c0a660f5218c072d33f2896024b47c05c6b /train_ti.py | |
| parent | Dependency cleanup/upgrades (diff) | |
Fix training
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 5 |
1 file changed, 2 insertions(+), 3 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index 5c0299e..9616db6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -24,7 +24,6 @@ from slugify import slugify
 
 from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
@@ -557,8 +556,8 @@ def main():
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
     vae.enable_slicing()
-    set_use_memory_efficient_attention_xformers(unet, True)
-    set_use_memory_efficient_attention_xformers(vae, True)
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
```
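
For context, a minimal sketch (not part of the commit) of the patched calls in isolation: it loads the VAE and UNet with diffusers and enables xformers memory-efficient attention through the models' own method, as the fixed code does. The checkpoint name and subfolder layout are assumptions for illustration, not taken from train_ti.py.

```python
# Minimal sketch (not from the commit): exercising the patched calls in
# isolation. The checkpoint name and subfolder layout are assumptions.
from diffusers import AutoencoderKL, UNet2DConditionModel

pretrained = "runwayml/stable-diffusion-v1-5"  # placeholder checkpoint

vae = AutoencoderKL.from_pretrained(pretrained, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(pretrained, subfolder="unet")

# Reduce VRAM use when encoding/decoding latents.
vae.enable_slicing()

# After the fix: toggle xformers attention via the models' own method
# rather than the removed pipelines.util helper.
vae.set_use_memory_efficient_attention_xformers(True)
unet.set_use_memory_efficient_attention_xformers(True)

# Optional, mirrors the surrounding train_ti.py code: trade compute for
# memory during training.
unet.enable_gradient_checkpointing()
```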
