diff options
author | Volpeon <git@volpeon.ink> | 2023-01-07 13:57:46 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-07 13:57:46 +0100 |
commit | 3ee13893f9a4973ac75f45fe9318c35760dd4b1f (patch) | |
tree | e652a54e6c241eef52ddb30f2d7048da8f306f7b /infer.py | |
parent | Update (diff) | |
download | textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.gz textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.bz2 textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.zip |
Added progressive aspect ratio bucketing
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 23 |
1 files changed, 12 insertions, 11 deletions
@@ -238,16 +238,15 @@ def create_pipeline(model, dtype): | |||
238 | return pipeline | 238 | return pipeline |
239 | 239 | ||
240 | 240 | ||
241 | def shuffle_prompts(prompts: list[str]) -> list[str]: | ||
242 | return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts] | ||
243 | |||
244 | |||
241 | @torch.inference_mode() | 245 | @torch.inference_mode() |
242 | def generate(output_dir: Path, pipeline, args): | 246 | def generate(output_dir: Path, pipeline, args): |
243 | if isinstance(args.prompt, str): | 247 | if isinstance(args.prompt, str): |
244 | args.prompt = [args.prompt] | 248 | args.prompt = [args.prompt] |
245 | 249 | ||
246 | if args.shuffle: | ||
247 | args.prompt *= args.batch_size | ||
248 | args.batch_size = 1 | ||
249 | args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] | ||
250 | |||
251 | args.prompt = [args.template.format(prompt) for prompt in args.prompt] | 250 | args.prompt = [args.template.format(prompt) for prompt in args.prompt] |
252 | 251 | ||
253 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 252 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
@@ -263,9 +262,6 @@ def generate(output_dir: Path, pipeline, args): | |||
263 | dir = output_dir.joinpath(slugify(prompt)[:100]) | 262 | dir = output_dir.joinpath(slugify(prompt)[:100]) |
264 | dir.mkdir(parents=True, exist_ok=True) | 263 | dir.mkdir(parents=True, exist_ok=True) |
265 | image_dir.append(dir) | 264 | image_dir.append(dir) |
266 | |||
267 | with open(dir.joinpath('prompt.txt'), 'w') as f: | ||
268 | f.write(prompt) | ||
269 | else: | 265 | else: |
270 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") | 266 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") |
271 | output_dir.mkdir(parents=True, exist_ok=True) | 267 | output_dir.mkdir(parents=True, exist_ok=True) |
@@ -306,9 +302,10 @@ def generate(output_dir: Path, pipeline, args): | |||
306 | ) | 302 | ) |
307 | 303 | ||
308 | seed = args.seed + i | 304 | seed = args.seed + i |
305 | prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt | ||
309 | generator = torch.Generator(device="cuda").manual_seed(seed) | 306 | generator = torch.Generator(device="cuda").manual_seed(seed) |
310 | images = pipeline( | 307 | images = pipeline( |
311 | prompt=args.prompt, | 308 | prompt=prompt, |
312 | negative_prompt=args.negative_prompt, | 309 | negative_prompt=args.negative_prompt, |
313 | height=args.height, | 310 | height=args.height, |
314 | width=args.width, | 311 | width=args.width, |
@@ -321,9 +318,13 @@ def generate(output_dir: Path, pipeline, args): | |||
321 | ).images | 318 | ).images |
322 | 319 | ||
323 | for j, image in enumerate(images): | 320 | for j, image in enumerate(images): |
321 | basename = f"{seed}_{j // len(args.prompt)}" | ||
324 | dir = image_dir[j % len(args.prompt)] | 322 | dir = image_dir[j % len(args.prompt)] |
325 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) | 323 | |
326 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) | 324 | image.save(dir.joinpath(f"{basename}.png")) |
325 | image.save(dir.joinpath(f"{basename}.jpg"), quality=85) | ||
326 | with open(dir.joinpath(f"{basename}.txt"), 'w') as f: | ||
327 | f.write(prompt[j % len(args.prompt)]) | ||
327 | 328 | ||
328 | if torch.cuda.is_available(): | 329 | if torch.cuda.is_available(): |
329 | torch.cuda.empty_cache() | 330 | torch.cuda.empty_cache() |