summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-11-14 18:41:38 +0100
committerVolpeon <git@volpeon.ink>2022-11-14 18:41:38 +0100
commit8ff51a771905d0d14a3c690f54eb644515730348 (patch)
treef1096181e912291f85d82d95af88a9f4257c1b35 /textual_inversion.py
parentUpdate (diff)
downloadtextual-inversion-diff-8ff51a771905d0d14a3c690f54eb644515730348.tar.gz
textual-inversion-diff-8ff51a771905d0d14a3c690f54eb644515730348.tar.bz2
textual-inversion-diff-8ff51a771905d0d14a3c690f54eb644515730348.zip
Refactoring
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 999161b..bf591bc 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -437,7 +437,7 @@ class Checkpointer:
437 generator=generator, 437 generator=generator,
438 ) 438 )
439 439
440 with torch.inference_mode(): 440 with torch.autocast("cuda"), torch.inference_mode():
441 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 441 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
442 all_samples = [] 442 all_samples = []
443 file_path = samples_path.joinpath(pool, f"step_{step}.png") 443 file_path = samples_path.joinpath(pool, f"step_{step}.png")
@@ -456,7 +456,7 @@ class Checkpointer:
456 negative_prompt=nprompt, 456 negative_prompt=nprompt,
457 height=self.sample_image_size, 457 height=self.sample_image_size,
458 width=self.sample_image_size, 458 width=self.sample_image_size,
459 latents=latents[:len(prompt)] if latents is not None else None, 459 latents_or_image=latents[:len(prompt)] if latents is not None else None,
460 generator=generator if latents is not None else None, 460 generator=generator if latents is not None else None,
461 guidance_scale=guidance_scale, 461 guidance_scale=guidance_scale,
462 eta=eta, 462 eta=eta,
@@ -670,7 +670,7 @@ def main():
670 ).to(accelerator.device) 670 ).to(accelerator.device)
671 pipeline.set_progress_bar_config(dynamic_ncols=True) 671 pipeline.set_progress_bar_config(dynamic_ncols=True)
672 672
673 with torch.inference_mode(): 673 with torch.autocast("cuda"), torch.inference_mode():
674 for batch in batched_data: 674 for batch in batched_data:
675 image_name = [p.class_image_path for p in batch] 675 image_name = [p.class_image_path for p in batch]
676 prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch] 676 prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
@@ -908,7 +908,7 @@ def main():
908 text_encoder.eval() 908 text_encoder.eval()
909 val_loss = 0.0 909 val_loss = 0.0
910 910
911 with torch.inference_mode(): 911 with torch.autocast("cuda"), torch.inference_mode():
912 for step, batch in enumerate(val_dataloader): 912 for step, batch in enumerate(val_dataloader):
913 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 913 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
914 latents = latents * 0.18215 914 latents = latents * 0.18215