diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100 |
| commit | 7505f7e843dc719622a15f4ee301609813763d77 (patch) | |
| tree | fe67640dce9fec4f625d6d1600c696cd7de006ee /infer.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.gz textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.tar.bz2 textual-inversion-diff-7505f7e843dc719622a15f4ee301609813763d77.zip | |
Code simplifications, avoid autocast
Diffstat (limited to 'infer.py')
| -rw-r--r-- | infer.py | 44 |
1 files changed, 22 insertions, 22 deletions
| @@ -209,6 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype): | |||
| 209 | return pipeline | 209 | return pipeline |
| 210 | 210 | ||
| 211 | 211 | ||
| 212 | @torch.inference_mode() | ||
| 212 | def generate(output_dir, pipeline, args): | 213 | def generate(output_dir, pipeline, args): |
| 213 | if isinstance(args.prompt, str): | 214 | if isinstance(args.prompt, str): |
| 214 | args.prompt = [args.prompt] | 215 | args.prompt = [args.prompt] |
| @@ -245,30 +246,29 @@ def generate(output_dir, pipeline, args): | |||
| 245 | elif args.scheduler == "kdpm2_a": | 246 | elif args.scheduler == "kdpm2_a": |
| 246 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) | 247 | pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config) |
| 247 | 248 | ||
| 248 | with torch.autocast("cuda"), torch.inference_mode(): | 249 | for i in range(args.batch_num): |
| 249 | for i in range(args.batch_num): | 250 | pipeline.set_progress_bar_config( |
| 250 | pipeline.set_progress_bar_config( | 251 | desc=f"Batch {i + 1} of {args.batch_num}", |
| 251 | desc=f"Batch {i + 1} of {args.batch_num}", | 252 | dynamic_ncols=True |
| 252 | dynamic_ncols=True | 253 | ) |
| 253 | ) | ||
| 254 | 254 | ||
| 255 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) | 255 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) |
| 256 | images = pipeline( | 256 | images = pipeline( |
| 257 | prompt=args.prompt, | 257 | prompt=args.prompt, |
| 258 | negative_prompt=args.negative_prompt, | 258 | negative_prompt=args.negative_prompt, |
| 259 | height=args.height, | 259 | height=args.height, |
| 260 | width=args.width, | 260 | width=args.width, |
| 261 | num_images_per_prompt=args.batch_size, | 261 | num_images_per_prompt=args.batch_size, |
| 262 | num_inference_steps=args.steps, | 262 | num_inference_steps=args.steps, |
| 263 | guidance_scale=args.guidance_scale, | 263 | guidance_scale=args.guidance_scale, |
| 264 | generator=generator, | 264 | generator=generator, |
| 265 | image=init_image, | 265 | image=init_image, |
| 266 | strength=args.image_noise, | 266 | strength=args.image_noise, |
| 267 | ).images | 267 | ).images |
| 268 | 268 | ||
| 269 | for j, image in enumerate(images): | 269 | for j, image in enumerate(images): |
| 270 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) | 270 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) |
| 271 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) | 271 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) |
| 272 | 272 | ||
| 273 | if torch.cuda.is_available(): | 273 | if torch.cuda.is_available(): |
| 274 | torch.cuda.empty_cache() | 274 | torch.cuda.empty_cache() |
