From db46c9ead869c0713abc34ab6b9a0378d85fe7b2 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Wed, 28 Sep 2022 11:49:55 +0200
Subject: Infer script: Store args, better path/file names

---
 infer.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/infer.py b/infer.py
index 70da08f..f2007e9 100644
--- a/infer.py
+++ b/infer.py
@@ -74,15 +74,22 @@ def parse_args():
     return args
 
 
+def save_args(basepath, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(f"{basepath}/args.json", "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def main():
     args = parse_args()
 
     seed = args.seed or torch.random.seed()
-    generator = torch.Generator(device="cuda").manual_seed(seed)
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}")
+    output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
     output_dir.mkdir(parents=True, exist_ok=True)
+    save_args(output_dir, args)
 
     tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
     text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
@@ -106,6 +113,7 @@ def main():
 
     with autocast("cuda"):
         for i in range(args.batch_num):
+            generator = torch.Generator(device="cuda").manual_seed(seed + i)
             images = pipeline(
                 [args.prompt] * args.batch_size,
                 num_inference_steps=args.steps,
@@ -114,7 +122,7 @@ def main():
             ).images
 
             for j, image in enumerate(images):
-                image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg"))
+                image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
 
 
 if __name__ == "__main__":
-- 
cgit v1.2.3-70-g09d2