diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 57 |
1 files changed, 28 insertions, 29 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 4f2de9e..09871d4 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -4,6 +4,7 @@ import math | |||
4 | import os | 4 | import os |
5 | import datetime | 5 | import datetime |
6 | import logging | 6 | import logging |
7 | import json | ||
7 | from pathlib import Path | 8 | from pathlib import Path |
8 | 9 | ||
9 | import numpy as np | 10 | import numpy as np |
@@ -22,8 +23,6 @@ from tqdm.auto import tqdm | |||
22 | from transformers import CLIPTextModel, CLIPTokenizer | 23 | from transformers import CLIPTextModel, CLIPTokenizer |
23 | from slugify import slugify | 24 | from slugify import slugify |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | import json | ||
26 | import os | ||
27 | 26 | ||
28 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
29 | 28 | ||
@@ -70,7 +69,7 @@ def parse_args(): | |||
70 | parser.add_argument( | 69 | parser.add_argument( |
71 | "--num_class_images", | 70 | "--num_class_images", |
72 | type=int, | 71 | type=int, |
73 | default=4, | 72 | default=200, |
74 | help="How many class images to generate per training image." | 73 | help="How many class images to generate per training image." |
75 | ) | 74 | ) |
76 | parser.add_argument( | 75 | parser.add_argument( |
@@ -141,7 +140,7 @@ def parse_args(): | |||
141 | parser.add_argument( | 140 | parser.add_argument( |
142 | "--lr_scheduler", | 141 | "--lr_scheduler", |
143 | type=str, | 142 | type=str, |
144 | default="constant", | 143 | default="linear", |
145 | help=( | 144 | help=( |
146 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 145 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
147 | ' "constant", "constant_with_warmup"]' | 146 | ' "constant", "constant_with_warmup"]' |
@@ -402,20 +401,20 @@ class Checkpointer: | |||
402 | generator=generator, | 401 | generator=generator, |
403 | ) | 402 | ) |
404 | 403 | ||
405 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 404 | with torch.inference_mode(): |
406 | all_samples = [] | 405 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
407 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 406 | all_samples = [] |
408 | file_path.parent.mkdir(parents=True, exist_ok=True) | 407 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
408 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
409 | 409 | ||
410 | data_enum = enumerate(data) | 410 | data_enum = enumerate(data) |
411 | 411 | ||
412 | for i in range(self.sample_batches): | 412 | for i in range(self.sample_batches): |
413 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 413 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
414 | prompt = [prompt.format(self.placeholder_token) | 414 | prompt = [prompt.format(self.placeholder_token) |
415 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | 415 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
416 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | 416 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] |
417 | 417 | ||
418 | with self.accelerator.autocast(): | ||
419 | samples = pipeline( | 418 | samples = pipeline( |
420 | prompt=prompt, | 419 | prompt=prompt, |
421 | negative_prompt=nprompt, | 420 | negative_prompt=nprompt, |
@@ -429,15 +428,15 @@ class Checkpointer: | |||
429 | output_type='pil' | 428 | output_type='pil' |
430 | )["sample"] | 429 | )["sample"] |
431 | 430 | ||
432 | all_samples += samples | 431 | all_samples += samples |
433 | 432 | ||
434 | del samples | 433 | del samples |
435 | 434 | ||
436 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | 435 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) |
437 | image_grid.save(file_path) | 436 | image_grid.save(file_path) |
438 | 437 | ||
439 | del all_samples | 438 | del all_samples |
440 | del image_grid | 439 | del image_grid |
441 | 440 | ||
442 | del unwrapped | 441 | del unwrapped |
443 | del scheduler | 442 | del scheduler |
@@ -623,7 +622,7 @@ def main(): | |||
623 | datamodule.setup() | 622 | datamodule.setup() |
624 | 623 | ||
625 | if args.num_class_images != 0: | 624 | if args.num_class_images != 0: |
626 | missing_data = [item for item in datamodule.data if not item[1].exists()] | 625 | missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] |
627 | 626 | ||
628 | if len(missing_data) != 0: | 627 | if len(missing_data) != 0: |
629 | batched_data = [missing_data[i:i+args.sample_batch_size] | 628 | batched_data = [missing_data[i:i+args.sample_batch_size] |
@@ -643,20 +642,20 @@ def main(): | |||
643 | pipeline.enable_attention_slicing() | 642 | pipeline.enable_attention_slicing() |
644 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 643 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
645 | 644 | ||
646 | for batch in batched_data: | 645 | with torch.inference_mode(): |
647 | image_name = [p[1] for p in batch] | 646 | for batch in batched_data: |
648 | prompt = [p[2].format(args.initializer_token) for p in batch] | 647 | image_name = [p.class_image_path for p in batch] |
649 | nprompt = [p[3] for p in batch] | 648 | prompt = [p.prompt.format(args.initializer_token) for p in batch] |
649 | nprompt = [p.nprompt for p in batch] | ||
650 | 650 | ||
651 | with accelerator.autocast(): | ||
652 | images = pipeline( | 651 | images = pipeline( |
653 | prompt=prompt, | 652 | prompt=prompt, |
654 | negative_prompt=nprompt, | 653 | negative_prompt=nprompt, |
655 | num_inference_steps=args.sample_steps | 654 | num_inference_steps=args.sample_steps |
656 | ).images | 655 | ).images |
657 | 656 | ||
658 | for i, image in enumerate(images): | 657 | for i, image in enumerate(images): |
659 | image.save(image_name[i]) | 658 | image.save(image_name[i]) |
660 | 659 | ||
661 | del pipeline | 660 | del pipeline |
662 | 661 | ||