From 6720c99f7082dc855059ad4afd6b3cb45b62bc1f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 1 Oct 2022 16:53:19 +0200 Subject: Fix seed, better progress bar, fix euler_a for batch size > 1 --- infer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index 40720ea..d917239 100644 --- a/infer.py +++ b/infer.py @@ -91,7 +91,7 @@ def create_cmd_parser(): parser.add_argument( "--seed", type=int, - default=torch.random.seed(), + default=None, ) parser.add_argument( "--config", @@ -167,11 +167,15 @@ def generate(output_dir, pipeline, args): output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") output_dir.mkdir(parents=True, exist_ok=True) + seed = args.seed or torch.random.seed() + save_args(output_dir, args) with autocast("cuda"): for i in range(args.batch_num): - generator = torch.Generator(device="cuda").manual_seed(args.seed + i) + pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}") + + generator = torch.Generator(device="cuda").manual_seed(seed + i) images = pipeline( prompt=[args.prompt] * args.batch_size, height=args.height, @@ -183,7 +187,7 @@ def generate(output_dir, pipeline, args): ).images for j, image in enumerate(images): - image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) + image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) class CmdParse(cmd.Cmd): -- cgit v1.2.3-54-g00ecf