From ca914af018632b6231fb3ee4fcd5cdbdc467c784 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 21 Oct 2022 09:50:46 +0200 Subject: Add optional TI functionality to Dreambooth --- infer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index 8e17c4e..01010eb 100644 --- a/infer.py +++ b/infer.py @@ -258,7 +258,7 @@ def generate(output_dir, pipeline, args): output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") output_dir.mkdir(parents=True, exist_ok=True) - seed = args.seed or torch.random.seed() + args.seed = args.seed or torch.random.seed() save_args(output_dir, args) @@ -276,7 +276,7 @@ def generate(output_dir, pipeline, args): dynamic_ncols=True ) - generator = torch.Generator(device="cuda").manual_seed(seed + i) + generator = torch.Generator(device="cuda").manual_seed(args.seed + i) images = pipeline( prompt=args.prompt * (args.batch_size // len(args.prompt)), height=args.height, @@ -290,7 +290,7 @@ def generate(output_dir, pipeline, args): ).images for j, image in enumerate(images): - image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) + image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) if torch.cuda.is_available(): torch.cuda.empty_cache() -- cgit v1.2.3-70-g09d2