diff options
Diffstat (limited to 'infer.py')
-rw-r--r-- | infer.py | 40 |
1 files changed, 29 insertions, 11 deletions
@@ -25,6 +25,7 @@ from diffusers import ( | |||
25 | ) | 25 | ) |
26 | from transformers import CLIPTextModel | 26 | from transformers import CLIPTextModel |
27 | 27 | ||
28 | from data.keywords import prompt_to_keywords, keywords_to_prompt | ||
28 | from models.clip.embeddings import patch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
29 | from models.clip.tokenizer import MultiCLIPTokenizer | 30 | from models.clip.tokenizer import MultiCLIPTokenizer |
30 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 31 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
@@ -49,6 +50,7 @@ default_cmds = { | |||
49 | "scheduler": "dpmsm", | 50 | "scheduler": "dpmsm", |
50 | "prompt": None, | 51 | "prompt": None, |
51 | "negative_prompt": None, | 52 | "negative_prompt": None, |
53 | "shuffle": True, | ||
52 | "image": None, | 54 | "image": None, |
53 | "image_noise": .7, | 55 | "image_noise": .7, |
54 | "width": 768, | 56 | "width": 768, |
@@ -126,6 +128,10 @@ def create_cmd_parser(): | |||
126 | nargs="*", | 128 | nargs="*", |
127 | ) | 129 | ) |
128 | parser.add_argument( | 130 | parser.add_argument( |
131 | "--shuffle", | ||
132 | type=bool, | ||
133 | ) | ||
134 | parser.add_argument( | ||
129 | "--image", | 135 | "--image", |
130 | type=str, | 136 | type=str, |
131 | ) | 137 | ) |
@@ -197,7 +203,7 @@ def load_embeddings(pipeline, embeddings_dir): | |||
197 | pipeline.text_encoder.text_model.embeddings, | 203 | pipeline.text_encoder.text_model.embeddings, |
198 | Path(embeddings_dir) | 204 | Path(embeddings_dir) |
199 | ) | 205 | ) |
200 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}") | 206 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
201 | 207 | ||
202 | 208 | ||
203 | def create_pipeline(model, dtype): | 209 | def create_pipeline(model, dtype): |
@@ -228,20 +234,35 @@ def create_pipeline(model, dtype): | |||
228 | 234 | ||
229 | 235 | ||
230 | @torch.inference_mode() | 236 | @torch.inference_mode() |
231 | def generate(output_dir, pipeline, args): | 237 | def generate(output_dir: Path, pipeline, args): |
232 | if isinstance(args.prompt, str): | 238 | if isinstance(args.prompt, str): |
233 | args.prompt = [args.prompt] | 239 | args.prompt = [args.prompt] |
234 | 240 | ||
241 | if args.shuffle: | ||
242 | args.prompt *= args.batch_size | ||
243 | args.batch_size = 1 | ||
244 | args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] | ||
245 | |||
235 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 246 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
236 | use_subdirs = len(args.prompt) != 1 | 247 | image_dir = [] |
237 | if use_subdirs: | 248 | |
249 | if len(args.prompt) != 1: | ||
238 | if len(args.project) != 0: | 250 | if len(args.project) != 0: |
239 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}") | 251 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}") |
240 | else: | 252 | else: |
241 | output_dir = output_dir.joinpath(now) | 253 | output_dir = output_dir.joinpath(now) |
254 | |||
255 | for prompt in args.prompt: | ||
256 | dir = output_dir.joinpath(slugify(prompt)[:100]) | ||
257 | dir.mkdir(parents=True, exist_ok=True) | ||
258 | image_dir.append(dir) | ||
259 | |||
260 | with open(dir.joinpath('prompt.txt'), 'w') as f: | ||
261 | f.write(prompt) | ||
242 | else: | 262 | else: |
243 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") | 263 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") |
244 | output_dir.mkdir(parents=True, exist_ok=True) | 264 | output_dir.mkdir(parents=True, exist_ok=True) |
265 | image_dir.append(output_dir) | ||
245 | 266 | ||
246 | args.seed = args.seed or torch.random.seed() | 267 | args.seed = args.seed or torch.random.seed() |
247 | 268 | ||
@@ -293,12 +314,9 @@ def generate(output_dir, pipeline, args): | |||
293 | ).images | 314 | ).images |
294 | 315 | ||
295 | for j, image in enumerate(images): | 316 | for j, image in enumerate(images): |
296 | image_dir = output_dir | 317 | dir = image_dir[j % len(args.prompt)] |
297 | if use_subdirs: | 318 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) |
298 | image_dir = image_dir.joinpath(slugify(args.prompt[j % len(args.prompt)])[:100]) | 319 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) |
299 | image_dir.mkdir(parents=True, exist_ok=True) | ||
300 | image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) | ||
301 | image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) | ||
302 | 320 | ||
303 | if torch.cuda.is_available(): | 321 | if torch.cuda.is_available(): |
304 | torch.cuda.empty_cache() | 322 | torch.cuda.empty_cache() |