diff options
| -rw-r--r-- | infer.py | 14 |
1 files changed, 11 insertions, 3 deletions
| @@ -74,15 +74,22 @@ def parse_args(): | |||
| 74 | return args | 74 | return args |
| 75 | 75 | ||
| 76 | 76 | ||
| 77 | def save_args(basepath, args, extra={}): | ||
| 78 | info = {"args": vars(args)} | ||
| 79 | info["args"].update(extra) | ||
| 80 | with open(f"{basepath}/args.json", "w") as f: | ||
| 81 | json.dump(info, f, indent=4) | ||
| 82 | |||
| 83 | |||
| 77 | def main(): | 84 | def main(): |
| 78 | args = parse_args() | 85 | args = parse_args() |
| 79 | 86 | ||
| 80 | seed = args.seed or torch.random.seed() | 87 | seed = args.seed or torch.random.seed() |
| 81 | generator = torch.Generator(device="cuda").manual_seed(seed) | ||
| 82 | 88 | ||
| 83 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 89 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 84 | output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}") | 90 | output_dir = Path(args.output_dir).joinpath(f"{now}_{slugify(args.prompt)[:100]}") |
| 85 | output_dir.mkdir(parents=True, exist_ok=True) | 91 | output_dir.mkdir(parents=True, exist_ok=True) |
| 92 | save_args(output_dir, args) | ||
| 86 | 93 | ||
| 87 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) | 94 | tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16) |
| 88 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) | 95 | text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16) |
| @@ -106,6 +113,7 @@ def main(): | |||
| 106 | 113 | ||
| 107 | with autocast("cuda"): | 114 | with autocast("cuda"): |
| 108 | for i in range(args.batch_num): | 115 | for i in range(args.batch_num): |
| 116 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | ||
| 109 | images = pipeline( | 117 | images = pipeline( |
| 110 | [args.prompt] * args.batch_size, | 118 | [args.prompt] * args.batch_size, |
| 111 | num_inference_steps=args.steps, | 119 | num_inference_steps=args.steps, |
| @@ -114,7 +122,7 @@ def main(): | |||
| 114 | ).images | 122 | ).images |
| 115 | 123 | ||
| 116 | for j, image in enumerate(images): | 124 | for j, image in enumerate(images): |
| 117 | image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg")) | 125 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) |
| 118 | 126 | ||
| 119 | 127 | ||
| 120 | if __name__ == "__main__": | 128 | if __name__ == "__main__": |
