summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--infer.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/infer.py b/infer.py
index 70da08f..f2007e9 100644
--- a/infer.py
+++ b/infer.py
@@ -74,15 +74,22 @@ def parse_args():
74 return args 74 return args
75 75
76 76
77def save_args(basepath, args, extra={}):
78 info = {"args": vars(args)}
79 info["args"].update(extra)
80 with open(f"{basepath}/args.json", "w") as f:
81 json.dump(info, f, indent=4)
82
83
77def main(): 84def main():
78 args = parse_args() 85 args = parse_args()
79 86
80 seed = args.seed or torch.random.seed() 87 seed = args.seed or torch.random.seed()
81 generator = torch.Generator(device="cuda").manual_seed(seed)
82 88
83 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 89 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
84 output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}") 90 output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}")
85 output_dir.mkdir(parents=True, exist_ok=True) 91 output_dir.mkdir(parents=True, exist_ok=True)
92 save_args(output_dir, args)
86 93
87 tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) 94 tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
88 text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) 95 text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
@@ -106,6 +113,7 @@ def main():
106 113
107 with autocast("cuda"): 114 with autocast("cuda"):
108 for i in range(args.batch_num): 115 for i in range(args.batch_num):
116 generator = torch.Generator(device="cuda").manual_seed(seed + i)
109 images = pipeline( 117 images = pipeline(
110 [args.prompt] * args.batch_size, 118 [args.prompt] * args.batch_size,
111 num_inference_steps=args.steps, 119 num_inference_steps=args.steps,
@@ -114,7 +122,7 @@ def main():
114 ).images 122 ).images
115 123
116 for j, image in enumerate(images): 124 for j, image in enumerate(images):
117 image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg")) 125 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
118 126
119 127
120if __name__ == "__main__": 128if __name__ == "__main__":