| author | Volpeon <git@volpeon.ink> | 2022-12-13 23:09:25 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-13 23:09:25 +0100 |
| commit | 03303d3bddba5a27a123babdf90863e27501e6f8 (patch) | |
| tree | 8266c50f8e474d92ad4b42773cb8eb7730cd24c1 /infer.py | |
| parent | Optimized Textual Inversion training by filtering dataset by existence of add... (diff) | |
Unified loading of TI embeddings
Diffstat (limited to 'infer.py')
```
 -rw-r--r--  infer.py | 34 ++--------------------------------
 1 file changed, 2 insertions(+), 32 deletions(-)
```
```diff
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from common import load_text_embeddings
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -180,37 +181,6 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
-    print(f"Loading Textual Inversion embeddings")
-
-    embeddings_dir = Path(embeddings_dir)
-    embeddings_dir.mkdir(parents=True, exist_ok=True)
-
-    placeholder_tokens = [file.stem for file in embeddings_dir.iterdir() if file.is_file()]
-    tokenizer.add_tokens(placeholder_tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
-
-            data = torch.load(file, map_location="cpu")
-
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-
-            token_embeds[placeholder_token_id] = emb
-
-            print(f"Loaded {placeholder_token}")
-
-
 def create_pipeline(model, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
@@ -220,7 +190,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
+    load_text_embeddings(tokenizer, text_encoder, Path(ti_embeddings_dir))
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
```
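
The shared helper that replaces the deleted `load_embeddings_ti` lives in `common.py`, which is not part of this diff. As a rough sketch only: assuming `common.load_text_embeddings` keeps the same behavior as the removed function (one `torch.load`-able file per placeholder token, with the file stem used as the token name), it would look roughly like this:

```python
# Sketch of common.load_text_embeddings, inferred from the removed
# load_embeddings_ti above -- common.py is not shown in this diff.
from pathlib import Path

import torch


def load_text_embeddings(tokenizer, text_encoder, embeddings_dir: Path):
    print("Loading Textual Inversion embeddings")

    embeddings_dir.mkdir(parents=True, exist_ok=True)
    files = [file for file in embeddings_dir.iterdir() if file.is_file()]

    # Register one placeholder token per embedding file (token name = file
    # stem) and grow the text encoder's embedding matrix to match.
    tokenizer.add_tokens([file.stem for file in files])
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data

    for file in files:
        placeholder_token = file.stem
        placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

        data = torch.load(file, map_location="cpu")
        assert len(data.keys()) == 1, "embedding file has multiple terms in it"

        # Each file stores a single {token: tensor} entry; normalize 1-D
        # vectors to shape (1, embedding_dim) before writing them in place.
        emb = next(iter(data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)

        token_embeds[placeholder_token_id] = emb
        print(f"Loaded {placeholder_token}")
```

With the loading logic centralized, `infer.py` and the training scripts can share one implementation; the only caller-side change here is passing a `Path` instead of a raw string.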
