Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  57
1 file changed, 28 insertions(+), 29 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 4f2de9e..09871d4 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -4,6 +4,7 @@ import math
 import os
 import datetime
 import logging
+import json
 from pathlib import Path
 
 import numpy as np
@@ -22,8 +23,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-import json
-import os
 
 from data.csv import CSVDataModule
 
@@ -70,7 +69,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=4,
+        default=200,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
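
Note: the class-image default jumps from 4 to 200 per training image, which changes the scale of the prior-preservation set considerably. A back-of-envelope check (the training-set size below is an illustrative assumption, not from this diff):

```python
# Illustrative only: how many class images the new default implies.
num_train_images = 50        # assumed example training-set size
num_class_images = 200       # new default from this commit
print(num_train_images * num_class_images)  # 10000 images to generate or cache
```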
@@ -141,7 +140,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="constant",
+        default="linear",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
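
Note: the default LR schedule moves from "constant" to "linear". The names listed in the help string match the schedules accepted by diffusers' `get_scheduler`; a minimal sketch of how such a flag is typically consumed (the optimizer and step counts below are assumptions, not taken from this script):

```python
# Minimal sketch: feeding an --lr_scheduler value to diffusers' get_scheduler.
import torch
from diffusers.optimization import get_scheduler

params = [torch.nn.Parameter(torch.zeros(8))]   # stand-in for the token embedding
optimizer = torch.optim.AdamW(params, lr=1e-4)

lr_scheduler = get_scheduler(
    "linear",                 # the new default
    optimizer=optimizer,
    num_warmup_steps=0,       # assumed; likely wired to a warmup-steps flag
    num_training_steps=1000,  # assumed total number of optimizer steps
)
```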
@@ -402,20 +401,20 @@ class Checkpointer:
             generator=generator,
         )
 
-        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
+        with torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
+                data_enum = enumerate(data)
 
-            for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.placeholder_token)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                for i in range(self.sample_batches):
+                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                    prompt = [prompt.format(self.placeholder_token)
+                              for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                    nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
-                with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
@@ -429,15 +428,15 @@ class Checkpointer:
                         output_type='pil'
                     )["sample"]
 
-                all_samples += samples
+                    all_samples += samples
 
-                del samples
+                    del samples
 
-            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path)
 
-            del all_samples
-            del image_grid
+                del all_samples
+                del image_grid
 
         del unwrapped
         del scheduler
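
Note: both sampling loops trade a per-batch `accelerator.autocast()` block for a single `torch.inference_mode()` guard around the whole loop. These are not equivalent: autocast controls dtypes, while `inference_mode` is a stricter, lower-overhead form of `no_grad` that disables autograd tracking entirely, so any mixed precision now has to come from how the pipeline weights were loaded. A minimal sketch of the adopted pattern (the pipeline and prompt batches are stand-ins):

```python
# Sketch of the pattern this commit adopts: one inference_mode() guard
# around the whole sampling loop instead of per-batch autocast blocks.
import torch

def sample_all(pipeline, prompt_batches, steps=20):
    images = []
    with torch.inference_mode():   # no autograd graph, no version counters
        for prompts in prompt_batches:
            images += pipeline(prompt=prompts, num_inference_steps=steps).images
    return images
```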
@@ -623,7 +622,7 @@ def main():
     datamodule.setup()
 
     if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data if not item[1].exists()]
+        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
         if len(missing_data) != 0:
             batched_data = [missing_data[i:i+args.sample_batch_size]
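
Note: the lookup for missing class images now reads `datamodule.data_train` and uses named attributes instead of positional tuple indexing (`item[1]`, `p[2]`, `p[3]`), which suggests the CSV data module now yields structured records. A stand-in with only the fields this diff actually touches (the class name is an assumption):

```python
from pathlib import Path
from typing import NamedTuple

class CSVDataItem(NamedTuple):   # hypothetical name for the record type
    class_image_path: Path
    prompt: str
    nprompt: str

item = CSVDataItem(Path("cls/0001.jpg"), "a photo of a {}", "")
needs_generation = not item.class_image_path.exists()
```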
@@ -643,20 +642,20 @@ def main():
             pipeline.enable_attention_slicing()
             pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-            for batch in batched_data:
-                image_name = [p[1] for p in batch]
-                prompt = [p[2].format(args.initializer_token) for p in batch]
-                nprompt = [p[3] for p in batch]
+            with torch.inference_mode():
+                for batch in batched_data:
+                    image_name = [p.class_image_path for p in batch]
+                    prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                    nprompt = [p.nprompt for p in batch]
 
-                with accelerator.autocast():
                     images = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
                         num_inference_steps=args.sample_steps
                     ).images
 
                     for i, image in enumerate(images):
                         image.save(image_name[i])
 
             del pipeline
 
