From 7505f7e843dc719622a15f4ee301609813763d77 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Dec 2022 23:50:24 +0100 Subject: Code simplifications, avoid autocast --- infer.py | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) (limited to 'infer.py') diff --git a/infer.py b/infer.py index 420cb83..f566114 100644 --- a/infer.py +++ b/infer.py @@ -209,6 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype): return pipeline +@torch.inference_mode() def generate(output_dir, pipeline, args): if isinstance(args.prompt, str): args.prompt = [args.prompt] @@ -245,30 +246,29 @@ def generate(output_dir, pipeline, args): elif args.scheduler == "kdpm2_a": pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) - with torch.autocast("cuda"), torch.inference_mode(): - for i in range(args.batch_num): - pipeline.set_progress_bar_config( - desc=f"Batch {i + 1} of {args.batch_num}", - dynamic_ncols=True - ) - - generator = torch.Generator(device="cuda").manual_seed(args.seed + i) - images = pipeline( - prompt=args.prompt, - negative_prompt=args.negative_prompt, - height=args.height, - width=args.width, - num_images_per_prompt=args.batch_size, - num_inference_steps=args.steps, - guidance_scale=args.guidance_scale, - generator=generator, - image=init_image, - strength=args.image_noise, - ).images - - for j, image in enumerate(images): - image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) - image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) + for i in range(args.batch_num): + pipeline.set_progress_bar_config( + desc=f"Batch {i + 1} of {args.batch_num}", + dynamic_ncols=True + ) + + generator = torch.Generator(device="cuda").manual_seed(args.seed + i) + images = pipeline( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_images_per_prompt=args.batch_size, + num_inference_steps=args.steps, + guidance_scale=args.guidance_scale, + generator=generator, + image=init_image, + strength=args.image_noise, + ).images + + for j, image in enumerate(images): + image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) + image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) if torch.cuda.is_available(): torch.cuda.empty_cache() -- cgit v1.2.3-54-g00ecf