Diffstat (limited to 'infer.py')

 infer.py | 56 ++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 44 insertions(+), 12 deletions(-)
@@ -45,6 +45,7 @@ default_args = {
 
 
 default_cmds = {
+    "project": "",
     "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
@@ -104,6 +105,12 @@ def create_cmd_parser():
         description="Simple example of a training script."
     )
     parser.add_argument(
+        "--project",
+        type=str,
+        default=None,
+        help="The name of the current project.",
+    )
+    parser.add_argument(
         "--scheduler",
         type=str,
         choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
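
The new flag is mirrored by the "project" entry in default_cmds (first hunk), so an unset --project can fall back to the empty string, which the len(args.project) != 0 check in generate() relies on, assuming run_parser merges default_cmds into the parsed flags. A hypothetical interactive invocation (flag values are illustrative; passing the prompt via --prompt is assumed from default_cmds):

    dream> --prompt "a watercolor castle" --project castle-study --scheduler dpmsm
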
@@ -184,7 +191,16 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def create_pipeline(model, embeddings_dir, dtype):
+def load_embeddings(pipeline, embeddings_dir):
+    added_tokens = load_embeddings_from_dir(
+        pipeline.tokenizer,
+        pipeline.text_encoder.text_model.embeddings,
+        Path(embeddings_dir)
+    )
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+
+def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
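
Pulling the embedding load out of create_pipeline means the pipeline can be built once and the textual-inversion embeddings applied, or later re-applied, on the live object. A minimal sketch of the new call order, mirroring main() below (model and embeddings paths are placeholders):

    pipeline = create_pipeline("models/my-model", torch.float16)
    load_embeddings(pipeline, "embeddings/")
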
@@ -193,10 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    embeddings = patch_managed_embeddings(text_encoder)
-    added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir))
-
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+    patch_managed_embeddings(text_encoder)
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
@@ -220,7 +233,14 @@ def generate(output_dir, pipeline, args):
         args.prompt = [args.prompt]
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
+    use_subdirs = len(args.prompt) != 1
+    if use_subdirs:
+        if len(args.project) != 0:
+            output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}")
+        else:
+            output_dir = output_dir.joinpath(now)
+    else:
+        output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
     output_dir.mkdir(parents=True, exist_ok=True)
 
     args.seed = args.seed or torch.random.seed()
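
The resulting layout, assuming slugify yields lowercase hyphenated names and using an illustrative timestamp:

    # one prompt                  -> <output_dir>/2023-01-14T10-00-00_a-watercolor-castle/
    # several prompts + --project -> <output_dir>/2023-01-14T10-00-00_castle-study/
    # several prompts, no project -> <output_dir>/2023-01-14T10-00-00/
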
@@ -257,7 +277,8 @@ def generate(output_dir, pipeline, args):
             dynamic_ncols=True
         )
 
-        generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
+        seed = args.seed + i
+        generator = torch.Generator(device="cuda").manual_seed(seed)
         images = pipeline(
             prompt=args.prompt,
             negative_prompt=args.negative_prompt,
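
Binding the per-batch seed to a name keeps the generator and the filenames written in the next hunk in sync; e.g. with a base seed of 1000 (value illustrative):

    # batch i: 0     1     2
    # seed:    1000  1001  1002  -> files 1000_0.png, 1001_0.png, ...
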
@@ -272,8 +293,13 @@ def generate(output_dir, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
-            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
-            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
+            image_dir = output_dir
+            if use_subdirs:
+                idx = j % len(args.prompt)
+                image_dir = image_dir.joinpath(slugify(args.prompt[idx])[:100])
+                image_dir.mkdir(parents=True, exist_ok=True)
+            image.save(image_dir.joinpath(f"{seed}_{j}.png"))
+            image.save(image_dir.joinpath(f"{seed}_{j}.jpg"), quality=85)
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
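
The j % len(args.prompt) mapping assumes the pipeline returns images cycling through the prompt list, one round per image per prompt; e.g. with prompts [A, B] and four images:

    # j:      0        1        2        3
    # idx:    0        1        0        1
    # subdir: slug(A)  slug(B)  slug(A)  slug(B)
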
@@ -283,10 +309,11 @@ class CmdParse(cmd.Cmd):
     prompt = 'dream> '
     commands = []
 
-    def __init__(self, output_dir, pipeline, parser):
+    def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser):
         super().__init__()
 
         self.output_dir = output_dir
+        self.ti_embeddings_dir = ti_embeddings_dir
         self.pipeline = pipeline
         self.parser = parser
 
@@ -302,6 +329,10 @@ class CmdParse(cmd.Cmd):
         if elements[0] == 'q':
             return True
 
+        if elements[0] == 'reload_embeddings':
+            load_embeddings(self.pipeline, self.ti_embeddings_dir)
+            return
+
         try:
             args = run_parser(self.parser, default_cmds, elements)
 
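
The new shell command re-reads the embeddings directory into the live pipeline without restarting. A hypothetical session (token names are illustrative; the message format comes from load_embeddings):

    dream> reload_embeddings
    Added 2 tokens from embeddings dir: ['<token-a>', '<token-b>']
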
@@ -337,9 +368,10 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, dtype)
+    load_embeddings(pipeline, args.ti_embeddings_dir)
     cmd_parser = create_cmd_parser()
-    cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
+    cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
 
 