summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py23
1 files changed, 12 insertions, 11 deletions
diff --git a/infer.py b/infer.py
index d3d5f1b..2b07b21 100644
--- a/infer.py
+++ b/infer.py
@@ -238,16 +238,15 @@ def create_pipeline(model, dtype):
238 return pipeline 238 return pipeline
239 239
240 240
241def shuffle_prompts(prompts: list[str]) -> list[str]:
242 return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts]
243
244
241@torch.inference_mode() 245@torch.inference_mode()
242def generate(output_dir: Path, pipeline, args): 246def generate(output_dir: Path, pipeline, args):
243 if isinstance(args.prompt, str): 247 if isinstance(args.prompt, str):
244 args.prompt = [args.prompt] 248 args.prompt = [args.prompt]
245 249
246 if args.shuffle:
247 args.prompt *= args.batch_size
248 args.batch_size = 1
249 args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt]
250
251 args.prompt = [args.template.format(prompt) for prompt in args.prompt] 250 args.prompt = [args.template.format(prompt) for prompt in args.prompt]
252 251
253 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 252 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
@@ -263,9 +262,6 @@ def generate(output_dir: Path, pipeline, args):
263 dir = output_dir.joinpath(slugify(prompt)[:100]) 262 dir = output_dir.joinpath(slugify(prompt)[:100])
264 dir.mkdir(parents=True, exist_ok=True) 263 dir.mkdir(parents=True, exist_ok=True)
265 image_dir.append(dir) 264 image_dir.append(dir)
266
267 with open(dir.joinpath('prompt.txt'), 'w') as f:
268 f.write(prompt)
269 else: 265 else:
270 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") 266 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
271 output_dir.mkdir(parents=True, exist_ok=True) 267 output_dir.mkdir(parents=True, exist_ok=True)
@@ -306,9 +302,10 @@ def generate(output_dir: Path, pipeline, args):
306 ) 302 )
307 303
308 seed = args.seed + i 304 seed = args.seed + i
305 prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt
309 generator = torch.Generator(device="cuda").manual_seed(seed) 306 generator = torch.Generator(device="cuda").manual_seed(seed)
310 images = pipeline( 307 images = pipeline(
311 prompt=args.prompt, 308 prompt=prompt,
312 negative_prompt=args.negative_prompt, 309 negative_prompt=args.negative_prompt,
313 height=args.height, 310 height=args.height,
314 width=args.width, 311 width=args.width,
@@ -321,9 +318,13 @@ def generate(output_dir: Path, pipeline, args):
321 ).images 318 ).images
322 319
323 for j, image in enumerate(images): 320 for j, image in enumerate(images):
321 basename = f"{seed}_{j // len(args.prompt)}"
324 dir = image_dir[j % len(args.prompt)] 322 dir = image_dir[j % len(args.prompt)]
325 image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) 323
326 image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) 324 image.save(dir.joinpath(f"{basename}.png"))
325 image.save(dir.joinpath(f"{basename}.jpg"), quality=85)
326 with open(dir.joinpath(f"{basename}.txt"), 'w') as f:
327 f.write(prompt[j % len(args.prompt)])
327 328
328 if torch.cuda.is_available(): 329 if torch.cuda.is_available():
329 torch.cuda.empty_cache() 330 torch.cuda.empty_cache()