diff options
author | Volpeon <git@volpeon.ink> | 2022-10-01 16:53:19 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-01 16:53:19 +0200 |
commit | 6720c99f7082dc855059ad4afd6b3cb45b62bc1f (patch) | |
tree | d27f69880472df0cd6f63ea42bbf7a789ec5d0b7 /infer.py | |
parent | Made inference script interactive (diff) | |
download | textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.tar.gz textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.tar.bz2 textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.zip |
Fix seed, better progress bar, fix euler_a for batch size > 1
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 10 |
1 files changed, 7 insertions, 3 deletions
@@ -91,7 +91,7 @@ def create_cmd_parser(): | |||
91 | parser.add_argument( | 91 | parser.add_argument( |
92 | "--seed", | 92 | "--seed", |
93 | type=int, | 93 | type=int, |
94 | default=torch.random.seed(), | 94 | default=None, |
95 | ) | 95 | ) |
96 | parser.add_argument( | 96 | parser.add_argument( |
97 | "--config", | 97 | "--config", |
@@ -167,11 +167,15 @@ def generate(output_dir, pipeline, args): | |||
167 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") | 167 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") |
168 | output_dir.mkdir(parents=True, exist_ok=True) | 168 | output_dir.mkdir(parents=True, exist_ok=True) |
169 | 169 | ||
170 | seed = args.seed or torch.random.seed() | ||
171 | |||
170 | save_args(output_dir, args) | 172 | save_args(output_dir, args) |
171 | 173 | ||
172 | with autocast("cuda"): | 174 | with autocast("cuda"): |
173 | for i in range(args.batch_num): | 175 | for i in range(args.batch_num): |
174 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) | 176 | pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}") |
177 | |||
178 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | ||
175 | images = pipeline( | 179 | images = pipeline( |
176 | prompt=[args.prompt] * args.batch_size, | 180 | prompt=[args.prompt] * args.batch_size, |
177 | height=args.height, | 181 | height=args.height, |
@@ -183,7 +187,7 @@ def generate(output_dir, pipeline, args): | |||
183 | ).images | 187 | ).images |
184 | 188 | ||
185 | for j, image in enumerate(images): | 189 | for j, image in enumerate(images): |
186 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) | 190 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) |
187 | 191 | ||
188 | 192 | ||
189 | class CmdParse(cmd.Cmd): | 193 | class CmdParse(cmd.Cmd): |