summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-01 16:53:19 +0200
committerVolpeon <git@volpeon.ink>2022-10-01 16:53:19 +0200
commit6720c99f7082dc855059ad4afd6b3cb45b62bc1f (patch)
treed27f69880472df0cd6f63ea42bbf7a789ec5d0b7 /infer.py
parentMade inference script interactive (diff)
downloadtextual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.tar.gz
textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.tar.bz2
textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.zip
Fix seed, better progress bar, fix euler_a for batch size > 1
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/infer.py b/infer.py
index 40720ea..d917239 100644
--- a/infer.py
+++ b/infer.py
@@ -91,7 +91,7 @@ def create_cmd_parser():
91 parser.add_argument( 91 parser.add_argument(
92 "--seed", 92 "--seed",
93 type=int, 93 type=int,
94 default=torch.random.seed(), 94 default=None,
95 ) 95 )
96 parser.add_argument( 96 parser.add_argument(
97 "--config", 97 "--config",
@@ -167,11 +167,15 @@ def generate(output_dir, pipeline, args):
167 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") 167 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
168 output_dir.mkdir(parents=True, exist_ok=True) 168 output_dir.mkdir(parents=True, exist_ok=True)
169 169
170 seed = args.seed or torch.random.seed()
171
170 save_args(output_dir, args) 172 save_args(output_dir, args)
171 173
172 with autocast("cuda"): 174 with autocast("cuda"):
173 for i in range(args.batch_num): 175 for i in range(args.batch_num):
174 generator = torch.Generator(device="cuda").manual_seed(args.seed + i) 176 pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
177
178 generator = torch.Generator(device="cuda").manual_seed(seed + i)
175 images = pipeline( 179 images = pipeline(
176 prompt=[args.prompt] * args.batch_size, 180 prompt=[args.prompt] * args.batch_size,
177 height=args.height, 181 height=args.height,
@@ -183,7 +187,7 @@ def generate(output_dir, pipeline, args):
183 ).images 187 ).images
184 188
185 for j, image in enumerate(images): 189 for j, image in enumerate(images):
186 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) 190 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
187 191
188 192
189class CmdParse(cmd.Cmd): 193class CmdParse(cmd.Cmd):