diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 34 |
1 file changed, 2 insertions, 32 deletions
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
24 | from slugify import slugify | 24 | from slugify import slugify |
25 | 25 | ||
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from common import load_text_embeddings | ||
27 | 28 | ||
28 | 29 | ||
29 | torch.backends.cuda.matmul.allow_tf32 = True | 30 | torch.backends.cuda.matmul.allow_tf32 = True |
@@ -180,37 +181,6 @@ def save_args(basepath, args, extra={}): | |||
180 | json.dump(info, f, indent=4) | 181 | json.dump(info, f, indent=4) |
181 | 182 | ||
182 | 183 | ||
183 | def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir): | ||
184 | print(f"Loading Textual Inversion embeddings") | ||
185 | |||
186 | embeddings_dir = Path(embeddings_dir) | ||
187 | embeddings_dir.mkdir(parents=True, exist_ok=True) | ||
188 | |||
189 | placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()] | ||
190 | tokenizer.add_tokens(placeholder_tokens) | ||
191 | |||
192 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
193 | |||
194 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
195 | |||
196 | for file in embeddings_dir.iterdir(): | ||
197 | if file.is_file(): | ||
198 | placeholder_token = file.stem | ||
199 | placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token) | ||
200 | |||
201 | data = torch.load(file, map_location="cpu") | ||
202 | |||
203 | assert len(data.keys()) == 1, 'embedding file has multiple terms in it' | ||
204 | |||
205 | emb = next(iter(data.values())) | ||
206 | if len(emb.shape) == 1: | ||
207 | emb = emb.unsqueeze(0) | ||
208 | |||
209 | token_embeds[placeholder_token_id] = emb | ||
210 | |||
211 | print(f"Loaded {placeholder_token}") | ||
212 | |||
213 | |||
214 | def create_pipeline(model, ti_embeddings_dir, dtype): | 184 | def create_pipeline(model, ti_embeddings_dir, dtype): |
215 | print("Loading Stable Diffusion pipeline...") | 185 | print("Loading Stable Diffusion pipeline...") |
216 | 186 | ||
@@ -220,7 +190,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype): | |||
220 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 190 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) |
221 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | 191 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) |
222 | 192 | ||
223 | load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir) | 193 | load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir)) |
224 | 194 | ||
225 | pipeline = VlpnStableDiffusion( | 195 | pipeline = VlpnStableDiffusion( |
226 | text_encoder=text_encoder, | 196 | text_encoder=text_encoder, |