From db46c9ead869c0713abc34ab6b9a0378d85fe7b2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 28 Sep 2022 11:49:55 +0200 Subject: Infer script: Store args, better path/file names --- infer.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index 70da08f..f2007e9 100644 --- a/infer.py +++ b/infer.py @@ -74,15 +74,22 @@ def parse_args(): return args +def save_args(basepath, args, extra={}): + info = {"args": vars(args)} + info["args"].update(extra) + with open(f"{basepath}/args.json", "w") as f: + json.dump(info, f, indent=4) + + def main(): args = parse_args() seed = args.seed or torch.random.seed() - generator = torch.Generator(device="cuda").manual_seed(seed) now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}") + output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") output_dir.mkdir(parents=True, exist_ok=True) + save_args(output_dir, args) tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) @@ -106,6 +113,7 @@ def main(): with autocast("cuda"): for i in range(args.batch_num): + generator = torch.Generator(device="cuda").manual_seed(seed + i) images = pipeline( [args.prompt] * args.batch_size, num_inference_steps=args.steps, @@ -114,7 +122,7 @@ def main(): ).images for j, image in enumerate(images): - image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg")) + image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) if __name__ == "__main__": -- cgit v1.2.3-54-g00ecf