summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-15 20:30:59 +0100
committerVolpeon <git@volpeon.ink>2022-12-15 20:30:59 +0100
commit8f4d212b3833041448678ad8a44a9a327934f74a (patch)
tree667edaef8a771a171db4a5afdae1fe8d427a2593 /infer.py
parentMore generic datset filter (diff)
downloadtextual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.tar.gz
textual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.tar.bz2
textual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.zip
Avoid increased VRAM usage on validation
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/infer.py b/infer.py
index efeb24d..420cb83 100644
--- a/infer.py
+++ b/infer.py
@@ -34,7 +34,7 @@ torch.backends.cudnn.benchmark = True
34default_args = { 34default_args = {
35 "model": "stabilityai/stable-diffusion-2-1", 35 "model": "stabilityai/stable-diffusion-2-1",
36 "precision": "fp32", 36 "precision": "fp32",
37 "ti_embeddings_dir": "embeddings_ti", 37 "ti_embeddings_dir": "embeddings",
38 "output_dir": "output/inference", 38 "output_dir": "output/inference",
39 "config": None, 39 "config": None,
40} 40}
@@ -190,7 +190,7 @@ def create_pipeline(model, embeddings_dir, dtype):
190 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) 190 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
191 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) 191 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
192 192
193 added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) 193 added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir))
194 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") 194 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
195 195
196 pipeline = VlpnStableDiffusion( 196 pipeline = VlpnStableDiffusion(