diff options
author | Volpeon <git@volpeon.ink> | 2022-10-03 14:47:01 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-03 14:47:01 +0200 |
commit | c90099f06e0b461660b326fb6d86b69d86e78289 (patch) | |
tree | df4ce274eed8f2a89bbd12f1a19c685ceac58ff2 /textual_inversion.py | |
parent | Fixed euler_a generator argument (diff) | |
download | textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.gz textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.bz2 textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.zip |
Added negative prompt support for training scripts
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 83 |
1 file changed, 23 insertions, 60 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 285aa0a..00d460f 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -207,16 +207,10 @@ def parse_args(): | |||
207 | help="Size of sample images", | 207 | help="Size of sample images", |
208 | ) | 208 | ) |
209 | parser.add_argument( | 209 | parser.add_argument( |
210 | "--stable_sample_batches", | 210 | "--sample_batches", |
211 | type=int, | 211 | type=int, |
212 | default=1, | 212 | default=1, |
213 | help="Number of fixed seed sample batches to generate per checkpoint", | 213 | help="Number of sample batches to generate per checkpoint", |
214 | ) | ||
215 | parser.add_argument( | ||
216 | "--random_sample_batches", | ||
217 | type=int, | ||
218 | default=1, | ||
219 | help="Number of random seed sample batches to generate per checkpoint", | ||
220 | ) | 214 | ) |
221 | parser.add_argument( | 215 | parser.add_argument( |
222 | "--sample_batch_size", | 216 | "--sample_batch_size", |
@@ -319,9 +313,8 @@ class Checkpointer: | |||
319 | placeholder_token_id, | 313 | placeholder_token_id, |
320 | output_dir, | 314 | output_dir, |
321 | sample_image_size, | 315 | sample_image_size, |
322 | random_sample_batches, | 316 | sample_batches, |
323 | sample_batch_size, | 317 | sample_batch_size, |
324 | stable_sample_batches, | ||
325 | seed | 318 | seed |
326 | ): | 319 | ): |
327 | self.datamodule = datamodule | 320 | self.datamodule = datamodule |
@@ -334,9 +327,8 @@ class Checkpointer: | |||
334 | self.output_dir = output_dir | 327 | self.output_dir = output_dir |
335 | self.sample_image_size = sample_image_size | 328 | self.sample_image_size = sample_image_size |
336 | self.seed = seed | 329 | self.seed = seed |
337 | self.random_sample_batches = random_sample_batches | 330 | self.sample_batches = sample_batches |
338 | self.sample_batch_size = sample_batch_size | 331 | self.sample_batch_size = sample_batch_size |
339 | self.stable_sample_batches = stable_sample_batches | ||
340 | 332 | ||
341 | @torch.no_grad() | 333 | @torch.no_grad() |
342 | def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): | 334 | def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): |
@@ -385,63 +377,33 @@ class Checkpointer: | |||
385 | train_data = self.datamodule.train_dataloader() | 377 | train_data = self.datamodule.train_dataloader() |
386 | val_data = self.datamodule.val_dataloader() | 378 | val_data = self.datamodule.val_dataloader() |
387 | 379 | ||
388 | if self.stable_sample_batches > 0: | 380 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
389 | stable_latents = torch.randn( | 381 | stable_latents = torch.randn( |
390 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | 382 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), |
391 | device=pipeline.device, | 383 | device=pipeline.device, |
392 | generator=torch.Generator(device=pipeline.device).manual_seed(self.seed), | 384 | generator=generator, |
393 | ) | 385 | ) |
394 | |||
395 | all_samples = [] | ||
396 | file_path = samples_path.joinpath("stable", f"step_{step}.png") | ||
397 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
398 | |||
399 | data_enum = enumerate(val_data) | ||
400 | |||
401 | # Generate and save stable samples | ||
402 | for i in range(0, self.stable_sample_batches): | ||
403 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | ||
404 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] | ||
405 | |||
406 | with self.accelerator.autocast(): | ||
407 | samples = pipeline( | ||
408 | prompt=prompt, | ||
409 | height=self.sample_image_size, | ||
410 | latents=stable_latents[:len(prompt)], | ||
411 | width=self.sample_image_size, | ||
412 | guidance_scale=guidance_scale, | ||
413 | eta=eta, | ||
414 | num_inference_steps=num_inference_steps, | ||
415 | output_type='pil' | ||
416 | )["sample"] | ||
417 | |||
418 | all_samples += samples | ||
419 | |||
420 | del samples | ||
421 | |||
422 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | ||
423 | image_grid.save(file_path) | ||
424 | |||
425 | del all_samples | ||
426 | del image_grid | ||
427 | del stable_latents | ||
428 | 386 | ||
429 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 387 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
430 | all_samples = [] | 388 | all_samples = [] |
431 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 389 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
432 | file_path.parent.mkdir(parents=True, exist_ok=True) | 390 | file_path.parent.mkdir(parents=True, exist_ok=True) |
433 | 391 | ||
434 | data_enum = enumerate(data) | 392 | data_enum = enumerate(data) |
435 | 393 | ||
436 | for i in range(0, self.random_sample_batches): | 394 | for i in range(self.sample_batches): |
437 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 395 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
438 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 396 | prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
397 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | ||
439 | 398 | ||
440 | with self.accelerator.autocast(): | 399 | with self.accelerator.autocast(): |
441 | samples = pipeline( | 400 | samples = pipeline( |
442 | prompt=prompt, | 401 | prompt=prompt, |
402 | negative_prompt=nprompt, | ||
443 | height=self.sample_image_size, | 403 | height=self.sample_image_size, |
444 | width=self.sample_image_size, | 404 | width=self.sample_image_size, |
405 | latents=latents[:len(prompt)] if latents is not None else None, | ||
406 | generator=generator if latents is not None else None, | ||
445 | guidance_scale=guidance_scale, | 407 | guidance_scale=guidance_scale, |
446 | eta=eta, | 408 | eta=eta, |
447 | num_inference_steps=num_inference_steps, | 409 | num_inference_steps=num_inference_steps, |
@@ -452,7 +414,7 @@ class Checkpointer: | |||
452 | 414 | ||
453 | del samples | 415 | del samples |
454 | 416 | ||
455 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 417 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) |
456 | image_grid.save(file_path) | 418 | image_grid.save(file_path) |
457 | 419 | ||
458 | del all_samples | 420 | del all_samples |
@@ -461,6 +423,8 @@ class Checkpointer: | |||
461 | del unwrapped | 423 | del unwrapped |
462 | del scheduler | 424 | del scheduler |
463 | del pipeline | 425 | del pipeline |
426 | del generator | ||
427 | del stable_latents | ||
464 | 428 | ||
465 | if torch.cuda.is_available(): | 429 | if torch.cuda.is_available(): |
466 | torch.cuda.empty_cache() | 430 | torch.cuda.empty_cache() |
@@ -603,7 +567,7 @@ def main(): | |||
603 | placeholder_token=args.placeholder_token, | 567 | placeholder_token=args.placeholder_token, |
604 | repeats=args.repeats, | 568 | repeats=args.repeats, |
605 | center_crop=args.center_crop, | 569 | center_crop=args.center_crop, |
606 | valid_set_size=args.sample_batch_size*args.stable_sample_batches | 570 | valid_set_size=args.sample_batch_size*args.sample_batches |
607 | ) | 571 | ) |
608 | 572 | ||
609 | datamodule.prepare_data() | 573 | datamodule.prepare_data() |
@@ -623,8 +587,7 @@ def main(): | |||
623 | output_dir=basepath, | 587 | output_dir=basepath, |
624 | sample_image_size=args.sample_image_size, | 588 | sample_image_size=args.sample_image_size, |
625 | sample_batch_size=args.sample_batch_size, | 589 | sample_batch_size=args.sample_batch_size, |
626 | random_sample_batches=args.random_sample_batches, | 590 | sample_batches=args.sample_batches, |
627 | stable_sample_batches=args.stable_sample_batches, | ||
628 | seed=args.seed or torch.random.seed() | 591 | seed=args.seed or torch.random.seed() |
629 | ) | 592 | ) |
630 | 593 | ||