diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-04 09:24:37 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-04 09:24:37 +0100 |
| commit | 2825ba2f2030b2fd3e841aad416a4fd28d67615a (patch) | |
| tree | 180afc9ae42c7d29b7229423d4db53183e6ad49b /textual_inversion.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.gz textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.tar.bz2 textual-inversion-diff-2825ba2f2030b2fd3e841aad416a4fd28d67615a.zip | |
Update
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index cd2d22b..1a5a8d0 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -23,6 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
| 23 | from slugify import slugify | 23 | from slugify import slugify |
| 24 | 24 | ||
| 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 26 | from pipelines.util import set_use_memory_efficient_attention_xformers | ||
| 26 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
| 27 | from training.optimization import get_one_cycle_schedule | 28 | from training.optimization import get_one_cycle_schedule |
| 28 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
| @@ -546,7 +547,8 @@ def main(): | |||
| 546 | args.pretrained_model_name_or_path, subfolder='scheduler') | 547 | args.pretrained_model_name_or_path, subfolder='scheduler') |
| 547 | 548 | ||
| 548 | vae.enable_slicing() | 549 | vae.enable_slicing() |
| 549 | unet.set_use_memory_efficient_attention_xformers(True) | 550 | set_use_memory_efficient_attention_xformers(unet, True) |
| 551 | set_use_memory_efficient_attention_xformers(vae, True) | ||
| 550 | 552 | ||
| 551 | if args.gradient_checkpointing: | 553 | if args.gradient_checkpointing: |
| 552 | text_encoder.gradient_checkpointing_enable() | 554 | text_encoder.gradient_checkpointing_enable() |
