summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-04 21:16:08 +0200
committerVolpeon <git@volpeon.ink>2022-10-04 21:16:08 +0200
commit19a013ba9efaad53b7fc0eef647671e7143efc2a (patch)
tree45d1f3d7289c3ccf0c906f03d035a4df88ca2161 /infer.py
parentMulti-vector TI was broken (diff)
downloadtextual-inversion-diff-19a013ba9efaad53b7fc0eef647671e7143efc2a.tar.gz
textual-inversion-diff-19a013ba9efaad53b7fc0eef647671e7143efc2a.tar.bz2
textual-inversion-diff-19a013ba9efaad53b7fc0eef647671e7143efc2a.zip
Inference: Add support for embeddings
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py32
1 files changed, 30 insertions, 2 deletions
diff --git a/infer.py b/infer.py
index 3dc0f32..3487e5a 100644
--- a/infer.py
+++ b/infer.py
@@ -23,6 +23,7 @@ default_args = {
23 "model": None, 23 "model": None,
24 "scheduler": "euler_a", 24 "scheduler": "euler_a",
25 "precision": "bf16", 25 "precision": "bf16",
26 "embeddings_dir": "embeddings",
26 "output_dir": "output/inference", 27 "output_dir": "output/inference",
27 "config": None, 28 "config": None,
28} 29}
@@ -72,6 +73,10 @@ def create_args_parser():
72 choices=["fp32", "fp16", "bf16"], 73 choices=["fp32", "fp16", "bf16"],
73 ) 74 )
74 parser.add_argument( 75 parser.add_argument(
76 "--embeddings_dir",
77 type=str,
78 )
79 parser.add_argument(
75 "--output_dir", 80 "--output_dir",
76 type=str, 81 type=str,
77 ) 82 )
@@ -162,7 +167,28 @@ def save_args(basepath, args, extra={}):
162 json.dump(info, f, indent=4) 167 json.dump(info, f, indent=4)
163 168
164 169
165def create_pipeline(model, scheduler, dtype): 170def load_embeddings(tokenizer, text_encoder, embeddings_dir):
171 embeddings_dir = Path(embeddings_dir)
172 embeddings_dir.mkdir(parents=True, exist_ok=True)
173
174 token_embeds = text_encoder.get_input_embeddings().weight.data
175
176 for file in embeddings_dir.iterdir():
177 placeholder_token = file.stem
178 placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
179
180 data = torch.load(file, map_location="cpu")
181
182 assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
183
184 emb = next(iter(data.values()))
185 if len(emb.shape) == 1:
186 emb = emb.unsqueeze(0)
187
188 token_embeds[placeholder_token_id] = emb
189
190
191def create_pipeline(model, scheduler, embeddings_dir, dtype):
166 print("Loading Stable Diffusion pipeline...") 192 print("Loading Stable Diffusion pipeline...")
167 193
168 tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) 194 tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
@@ -170,6 +196,8 @@ def create_pipeline(model, scheduler, dtype):
170 vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) 196 vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
171 unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) 197 unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
172 198
199 load_embeddings(tokenizer, text_encoder, embeddings_dir)
200
173 if scheduler == "plms": 201 if scheduler == "plms":
174 scheduler = PNDMScheduler( 202 scheduler = PNDMScheduler(
175 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True 203 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
@@ -290,7 +318,7 @@ def main():
290 output_dir = Path(args.output_dir) 318 output_dir = Path(args.output_dir)
291 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] 319 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
292 320
293 pipeline = create_pipeline(args.model, args.scheduler, dtype) 321 pipeline = create_pipeline(args.model, args.scheduler, args.embeddings_dir, dtype)
294 cmd_parser = create_cmd_parser() 322 cmd_parser = create_cmd_parser()
295 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) 323 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
296 cmd_prompt.cmdloop() 324 cmd_prompt.cmdloop()