diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 48 |
1 files changed, 24 insertions, 24 deletions
@@ -209,6 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype): | |||
209 | return pipeline | 209 | return pipeline |
210 | 210 | ||
211 | 211 | ||
212 | @torch.inference_mode() | ||
212 | def generate(output_dir, pipeline, args): | 213 | def generate(output_dir, pipeline, args): |
213 | if isinstance(args.prompt, str): | 214 | if isinstance(args.prompt, str): |
214 | args.prompt = [args.prompt] | 215 | args.prompt = [args.prompt] |
@@ -245,30 +246,29 @@ def generate(output_dir, pipeline, args): | |||
245 | elif args.scheduler == "kdpm2_a": | 246 | elif args.scheduler == "kdpm2_a": |
246 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | 247 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) |
247 | 248 | ||
248 | with torch.autocast("cuda"), torch.inference_mode(): | 249 | for i in range(args.batch_num): |
249 | for i in range(args.batch_num): | 250 | pipeline.set_progress_bar_config( |
250 | pipeline.set_progress_bar_config( | 251 | desc=f"Batch {i + 1} of {args.batch_num}", |
251 | desc=f"Batch {i + 1} of {args.batch_num}", | 252 | dynamic_ncols=True |
252 | dynamic_ncols=True | 253 | ) |
253 | ) | 254 | |
254 | 255 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) | |
255 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) | 256 | images = pipeline( |
256 | images = pipeline( | 257 | prompt=args.prompt, |
257 | prompt=args.prompt, | 258 | negative_prompt=args.negative_prompt, |
258 | negative_prompt=args.negative_prompt, | 259 | height=args.height, |
259 | height=args.height, | 260 | width=args.width, |
260 | width=args.width, | 261 | num_images_per_prompt=args.batch_size, |
261 | num_images_per_prompt=args.batch_size, | 262 | num_inference_steps=args.steps, |
262 | num_inference_steps=args.steps, | 263 | guidance_scale=args.guidance_scale, |
263 | guidance_scale=args.guidance_scale, | 264 | generator=generator, |
264 | generator=generator, | 265 | image=init_image, |
265 | image=init_image, | 266 | strength=args.image_noise, |
266 | strength=args.image_noise, | 267 | ).images |
267 | ).images | 268 | |
268 | 269 | for j, image in enumerate(images): | |
269 | for j, image in enumerate(images): | 270 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) |
270 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) | 271 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) |
271 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) | ||
272 | 272 | ||
273 | if torch.cuda.is_available(): | 273 | if torch.cuda.is_available(): |
274 | torch.cuda.empty_cache() | 274 | torch.cuda.empty_cache() |