author    Volpeon <git@volpeon.ink>  2022-12-25 23:50:24 +0100
committer Volpeon <git@volpeon.ink>  2022-12-25 23:50:24 +0100
commit    7505f7e843dc719622a15f4ee301609813763d77 (patch)
tree      fe67640dce9fec4f625d6d1600c696cd7de006ee /infer.py
parent    Update (diff)
Code simplifications, avoid autocast
Diffstat (limited to 'infer.py')
-rw-r--r--  infer.py  48
1 file changed, 24 insertions(+), 24 deletions(-)
diff --git a/infer.py b/infer.py
index 420cb83..f566114 100644
--- a/infer.py
+++ b/infer.py
@@ -209,6 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype):
     return pipeline
 
 
+@torch.inference_mode()
 def generate(output_dir, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
@@ -245,30 +246,29 @@ def generate(output_dir, pipeline, args):
     elif args.scheduler == "kdpm2_a":
         pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
-    with torch.autocast("cuda"), torch.inference_mode():
-        for i in range(args.batch_num):
-            pipeline.set_progress_bar_config(
-                desc=f"Batch {i + 1} of {args.batch_num}",
-                dynamic_ncols=True
-            )
-
-            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
-            images = pipeline(
-                prompt=args.prompt,
-                negative_prompt=args.negative_prompt,
-                height=args.height,
-                width=args.width,
-                num_images_per_prompt=args.batch_size,
-                num_inference_steps=args.steps,
-                guidance_scale=args.guidance_scale,
-                generator=generator,
-                image=init_image,
-                strength=args.image_noise,
-            ).images
-
-            for j, image in enumerate(images):
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
+    for i in range(args.batch_num):
+        pipeline.set_progress_bar_config(
+            desc=f"Batch {i + 1} of {args.batch_num}",
+            dynamic_ncols=True
+        )
+
+        generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
+        images = pipeline(
+            prompt=args.prompt,
+            negative_prompt=args.negative_prompt,
+            height=args.height,
+            width=args.width,
+            num_images_per_prompt=args.batch_size,
+            num_inference_steps=args.steps,
+            guidance_scale=args.guidance_scale,
+            generator=generator,
+            image=init_image,
+            strength=args.image_noise,
+        ).images
+
+        for j, image in enumerate(images):
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
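
As the hunk above shows, the commit replaces the per-call "with torch.autocast("cuda"), torch.inference_mode():" block with a single @torch.inference_mode() decorator on generate(), leaving precision to whatever dtype the weights were loaded with (the dtype argument of create_pipeline). A minimal sketch of that pattern follows; the helper names (load_pipeline, run) are illustrative rather than the repository's actual code, and diffusers' StableDiffusionPipeline stands in for the project's own pipeline class:

import torch
from diffusers import StableDiffusionPipeline

def load_pipeline(model: str, dtype: torch.dtype = torch.float16):
    # Load the weights directly in the target dtype instead of relying on
    # autocast at call time; every forward pass then runs in that precision.
    return StableDiffusionPipeline.from_pretrained(model, torch_dtype=dtype).to("cuda")

@torch.inference_mode()  # decorator form of the former context manager
def run(pipeline, prompt: str, seed: int, batch_size: int = 4):
    # A seeded generator keeps batches reproducible, mirroring generate() above.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    return pipeline(
        prompt=prompt,
        num_images_per_prompt=batch_size,
        generator=generator,
    ).images

torch.inference_mode() is stricter than torch.no_grad(): tensors created under it skip version-counter and view tracking, which makes a pure generation loop slightly cheaper and rules out accidental autograd use.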