diff options
-rw-r--r-- | infer.py | 14 |
1 files changed, 11 insertions, 3 deletions
@@ -74,15 +74,22 @@ def parse_args(): | |||
74 | return args | 74 | return args |
75 | 75 | ||
76 | 76 | ||
77 | def save_args(basepath, args, extra={}): | ||
78 | info = {"args": vars(args)} | ||
79 | info["args"].update(extra) | ||
80 | with open(f"{basepath}/args.json", "w") as f: | ||
81 | json.dump(info, f, indent=4) | ||
82 | |||
83 | |||
77 | def main(): | 84 | def main(): |
78 | args = parse_args() | 85 | args = parse_args() |
79 | 86 | ||
80 | seed = args.seed or torch.random.seed() | 87 | seed = args.seed or torch.random.seed() |
81 | generator = torch.Generator(device="cuda").manual_seed(seed) | ||
82 | 88 | ||
83 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 89 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
84 | output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}") | 90 | output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") |
85 | output_dir.mkdir(parents=True, exist_ok=True) | 91 | output_dir.mkdir(parents=True, exist_ok=True) |
92 | save_args(output_dir, args) | ||
86 | 93 | ||
87 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) | 94 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) |
88 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) | 95 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) |
@@ -106,6 +113,7 @@ def main(): | |||
106 | 113 | ||
107 | with autocast("cuda"): | 114 | with autocast("cuda"): |
108 | for i in range(args.batch_num): | 115 | for i in range(args.batch_num): |
116 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | ||
109 | images = pipeline( | 117 | images = pipeline( |
110 | [args.prompt] * args.batch_size, | 118 | [args.prompt] * args.batch_size, |
111 | num_inference_steps=args.steps, | 119 | num_inference_steps=args.steps, |
@@ -114,7 +122,7 @@ def main(): | |||
114 | ).images | 122 | ).images |
115 | 123 | ||
116 | for j, image in enumerate(images): | 124 | for j, image in enumerate(images): |
117 | image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg")) | 125 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) |
118 | 126 | ||
119 | 127 | ||
120 | if __name__ == "__main__": | 128 | if __name__ == "__main__": |