author     Volpeon <git@volpeon.ink>  2023-01-05 13:26:32 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-05 13:26:32 +0100
commit     3396ca881ed3f3521617cd9024eea56975191d32 (patch)
tree       3189c3bbe77b211152d11b524d0fe3a7016441ee /infer.py
parent     Fix (diff)
download   textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.tar.gz
           textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.tar.bz2
           textual-inversion-diff-3396ca881ed3f3521617cd9024eea56975191d32.zip
Update
Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py  40
1 file changed, 29 insertions, 11 deletions
diff --git a/infer.py b/infer.py
index 507d0cf..9c27db4 100644
--- a/infer.py
+++ b/infer.py
@@ -25,6 +25,7 @@ from diffusers import (
 )
 from transformers import CLIPTextModel
 
+from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
@@ -49,6 +50,7 @@ default_cmds = {
49 "scheduler": "dpmsm", 50 "scheduler": "dpmsm",
50 "prompt": None, 51 "prompt": None,
51 "negative_prompt": None, 52 "negative_prompt": None,
53 "shuffle": True,
52 "image": None, 54 "image": None,
53 "image_noise": .7, 55 "image_noise": .7,
54 "width": 768, 56 "width": 768,
@@ -126,6 +128,10 @@ def create_cmd_parser():
         nargs="*",
     )
     parser.add_argument(
+        "--shuffle",
+        type=bool,
+    )
+    parser.add_argument(
         "--image",
         type=str,
     )
@@ -197,7 +203,7 @@ def load_embeddings(pipeline, embeddings_dir):
         pipeline.text_encoder.text_model.embeddings,
         Path(embeddings_dir)
     )
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
 
 def create_pipeline(model, dtype):
@@ -228,20 +234,35 @@ def create_pipeline(model, dtype):
 
 
 @torch.inference_mode()
-def generate(output_dir, pipeline, args):
+def generate(output_dir: Path, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
 
+    if args.shuffle:
+        args.prompt *= args.batch_size
+        args.batch_size = 1
+        args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt]
+
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    use_subdirs = len(args.prompt) != 1
-    if use_subdirs:
+    image_dir = []
+
+    if len(args.prompt) != 1:
         if len(args.project) != 0:
             output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}")
         else:
             output_dir = output_dir.joinpath(now)
+
+        for prompt in args.prompt:
+            dir = output_dir.joinpath(slugify(prompt)[:100])
+            dir.mkdir(parents=True, exist_ok=True)
+            image_dir.append(dir)
+
+            with open(dir.joinpath('prompt.txt'), 'w') as f:
+                f.write(prompt)
     else:
         output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
-    output_dir.mkdir(parents=True, exist_ok=True)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        image_dir.append(output_dir)
 
     args.seed = args.seed or torch.random.seed()
 
@@ -293,12 +314,9 @@ def generate(output_dir, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
-            image_dir = output_dir
-            if use_subdirs:
-                image_dir = image_dir.joinpath(slugify(args.prompt[j % len(args.prompt)])[:100])
-                image_dir.mkdir(parents=True, exist_ok=True)
-            image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.png"))
-            image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85)
+            dir = image_dir[j % len(args.prompt)]
+            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png"))
+            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85)
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
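
Note on the new --shuffle path introduced by this commit: generate() now replicates the prompt list batch_size times, resets batch_size to 1, and passes each copy through prompt_to_keywords / keywords_to_prompt with shuffle=True, so every generated image sees the same keywords in a different order. Those helpers live in data/keywords.py and are not part of this diff; the sketch below uses hypothetical stand-ins that assume a comma-separated prompt is split, shuffled, and rejoined, purely to illustrate the transformation.

import random

def prompt_to_keywords(prompt: str) -> list[str]:
    # Hypothetical stand-in for data.keywords.prompt_to_keywords:
    # split a comma-separated prompt into individual keywords.
    return [part.strip() for part in prompt.split(",") if part.strip()]

def keywords_to_prompt(keywords: list[str], shuffle: bool = False) -> str:
    # Hypothetical stand-in for data.keywords.keywords_to_prompt:
    # optionally shuffle the keywords, then join them back into one prompt.
    if shuffle:
        keywords = keywords.copy()
        random.shuffle(keywords)
    return ", ".join(keywords)

prompt = "a photo of <token>, highly detailed, studio lighting"
print(keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True))
# e.g. "studio lighting, a photo of <token>, highly detailed"

In the rewritten save loop, dir = image_dir[j % len(args.prompt)] routes image j back to the directory prepared for its prompt (each multi-prompt subdirectory also gets a prompt.txt with the prompt that produced it), and {seed}_{j // len(args.prompt)} numbers the images within that directory.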