author    Volpeon <git@volpeon.ink>  2022-12-13 23:09:25 +0100
committer Volpeon <git@volpeon.ink>  2022-12-13 23:09:25 +0100
commit    03303d3bddba5a27a123babdf90863e27501e6f8 (patch)
tree      8266c50f8e474d92ad4b42773cb8eb7730cd24c1 /infer.py
parent    Optimized Textual Inversion training by filtering dataset by existence of add... (diff)
Unified loading of TI embeddings
Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py  34
1 file changed, 2 insertions(+), 32 deletions(-)
diff --git a/infer.py b/infer.py
index f607041..1fd11e2 100644
--- a/infer.py
+++ b/infer.py
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from common import load_text_embeddings
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -180,37 +181,6 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
-    print(f"Loading Textual Inversion embeddings")
-
-    embeddings_dir = Path(embeddings_dir)
-    embeddings_dir.mkdir(parents=True, exist_ok=True)
-
-    placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()]
-    tokenizer.add_tokens(placeholder_tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
-
-            data = torch.load(file, map_location="cpu")
-
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[placeholder_token_id] = emb
-
-            print(f"Loaded {placeholder_token}")
-
-
 def create_pipeline(model, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
@@ -220,7 +190,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
+    load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir))
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
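
Note: the new shared helper, common.load_text_embeddings, lives in common.py and is outside this diff. For reference, a minimal sketch of what it plausibly contains, reconstructed from the load_embeddings_ti body removed above (the actual common.py may differ, e.g. in how it validates embedding files):

    from pathlib import Path

    import torch


    def load_text_embeddings(tokenizer, text_encoder, embeddings_dir: Path):
        # Each file in the directory holds one learned embedding; the file
        # stem is the placeholder token it should be registered under.
        embeddings_dir.mkdir(parents=True, exist_ok=True)
        files = [f for f in embeddings_dir.iterdir() if f.is_file()]

        # Register all placeholder tokens, then grow the text encoder's
        # embedding matrix so each new token gets a row.
        tokenizer.add_tokens([f.stem for f in files])
        text_encoder.resize_token_embeddings(len(tokenizer))
        token_embeds = text_encoder.get_input_embeddings().weight.data

        for file in files:
            placeholder_token = file.stem
            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

            data = torch.load(file, map_location="cpu")
            assert len(data.keys()) == 1, "embedding file has multiple terms in it"

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)

            # Overwrite the freshly added row with the learned embedding.
            token_embeds[placeholder_token_id] = emb
            print(f"Loaded {placeholder_token}")

Moving this logic into common lets infer.py and the training scripts share one code path; the call site in create_pipeline now passes Path(ti_embeddings_dir) explicitly, so the string-to-Path coercion happens at the caller instead of inside the helper.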