 infer.py | 32 ++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)
@@ -23,6 +23,7 @@ default_args = {
     "model": None,
     "scheduler": "euler_a",
     "precision": "bf16",
+    "embeddings_dir": "embeddings",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -72,6 +73,10 @@ def create_args_parser():
         choices=["fp32", "fp16", "bf16"],
     )
     parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+    )
+    parser.add_argument(
         "--output_dir",
         type=str,
     )
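With the new flag wired into create_args_parser(), the embeddings directory can be overridden at invocation time. A hypothetical call (the model path is illustrative; only --embeddings_dir and --precision are shown in this diff):

    python infer.py --model <model-dir> --precision bf16 --embeddings_dir embeddings

When the flag is omitted, the "embeddings" default from default_args applies.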
@@ -162,7 +167,28 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def create_pipeline(model, scheduler, dtype):
+def load_embeddings(tokenizer, text_encoder, embeddings_dir):
+    embeddings_dir = Path(embeddings_dir)
+    embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
+    for file in embeddings_dir.iterdir():
+        placeholder_token = file.stem
+        placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
+
+        data = torch.load(file, map_location="cpu")
+
+        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+        emb = next(iter(data.values()))
+        if len(emb.shape) == 1:
+            emb = emb.unsqueeze(0)
+
+        token_embeds[placeholder_token_id] = emb
+
+
+def create_pipeline(model, scheduler, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
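For reference, load_embeddings() uses the filename stem as the placeholder token and asserts that each file holds exactly one entry. A minimal sketch of a file it would accept; the token name and the 768 dimension (CLIP ViT-L/14 hidden size) are illustrative assumptions, not part of the commit:

    import torch

    # One learned token vector; 768 matches CLIP ViT-L/14's hidden size (assumed).
    embedding = torch.randn(768)

    # The loader takes the token name from the file stem, so this file
    # would be looked up as the token "<my-style>".
    torch.save({"<my-style>": embedding}, "embeddings/<my-style>.bin")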
@@ -170,6 +196,8 @@ def create_pipeline(model, scheduler, dtype):
     vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
 
+    load_embeddings(tokenizer, text_encoder, embeddings_dir)
+
     if scheduler == "plms":
         scheduler = PNDMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
@@ -290,7 +318,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, dtype)
+    pipeline = create_pipeline(args.model, args.scheduler, args.embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
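One caveat worth noting: convert_tokens_to_ids() returns the unknown-token id for any token not already in the tokenizer's vocabulary, so a placeholder that was never added to the tokenizer would silently overwrite the unk embedding. A more defensive loop body might register the token first; this is a sketch using the standard transformers API, not part of this commit:

    # Register the placeholder token before writing its embedding, so tokens
    # missing from the vocabulary do not resolve to the unk id.
    num_added = tokenizer.add_tokens(placeholder_token)
    if num_added > 0:
        # Grow the text encoder's embedding matrix to cover the new token.
        text_encoder.resize_token_embeddings(len(tokenizer))
    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
    token_embeds = text_encoder.get_input_embeddings().weight.data
    token_embeds[placeholder_token_id] = emb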