diff options
author | Volpeon <git@volpeon.ink> | 2022-12-15 20:30:59 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-15 20:30:59 +0100 |
commit | 8f4d212b3833041448678ad8a44a9a327934f74a (patch) | |
tree | 667edaef8a771a171db4a5afdae1fe8d427a2593 /infer.py | |
parent | More generic datset filter (diff) | |
download | textual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.tar.gz textual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.tar.bz2 textual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.zip |
Avoid increased VRAM usage on validation
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 4 |
1 files changed, 2 insertions, 2 deletions
@@ -34,7 +34,7 @@ torch.backends.cudnn.benchmark = True | |||
34 | default_args = { | 34 | default_args = { |
35 | "model": "stabilityai/stable-diffusion-2-1", | 35 | "model": "stabilityai/stable-diffusion-2-1", |
36 | "precision": "fp32", | 36 | "precision": "fp32", |
37 | "ti_embeddings_dir": "embeddings_ti", | 37 | "ti_embeddings_dir": "embeddings", |
38 | "output_dir": "output/inference", | 38 | "output_dir": "output/inference", |
39 | "config": None, | 39 | "config": None, |
40 | } | 40 | } |
@@ -190,7 +190,7 @@ def create_pipeline(model, embeddings_dir, dtype): | |||
190 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 190 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) |
191 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | 191 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) |
192 | 192 | ||
193 | added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) | 193 | added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir)) |
194 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") | 194 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") |
195 | 195 | ||
196 | pipeline = VlpnStableDiffusion( | 196 | pipeline = VlpnStableDiffusion( |