| field | value | date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2022-09-27 22:16:13 +0200 |
| committer | Volpeon <git@volpeon.ink> | 2022-09-27 22:16:13 +0200 |
| commit | a8a5abae42f6f42056cc27e0cf5313aab080c3a7 (patch) | |
| tree | 32c163bbc58aa2f827c5ba5108df81dc14fbe130 | |
| parent | Incorporate upstream changes (diff) | |
| download | textual-inversion-diff-a8a5abae42f6f42056cc27e0cf5313aab080c3a7.tar.gz, textual-inversion-diff-a8a5abae42f6f42056cc27e0cf5313aab080c3a7.tar.bz2, textual-inversion-diff-a8a5abae42f6f42056cc27e0cf5313aab080c3a7.zip | |
Various improvements, added inference script
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | .gitignore | 1 |
| -rw-r--r-- | data/dreambooth/csv.py | 1 |
| -rw-r--r-- | dreambooth.py | 90 |
| -rw-r--r-- | infer.py | 121 |
4 files changed, 171 insertions, 42 deletions
```diff
diff --git a/.gitignore b/.gitignore
--- a/.gitignore
+++ b/.gitignore
@@ -161,5 +161,6 @@ cython_debug/
 
 text-inversion-model/
 dreambooth-model/
+inference/
 conf*.json
 v1-inference.yaml*
```
```diff
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index e70c068..14c13bb 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -117,6 +117,7 @@ class CSVDataset(Dataset):
             [
                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
```
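The only dataset change is the new `transforms.RandomHorizontalFlip()`, which adds left-right flips as a cheap augmentation. For reference, a standalone sketch of the transform stack the dataset now applies (plain torchvision, not the repository's `CSVDataset`; `size` and `center_crop` stand in for the constructor arguments):

```python
from torchvision import transforms

size = 512           # assumed training resolution
center_crop = False  # assumed default

image_transforms = transforms.Compose([
    transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
    transforms.RandomHorizontalFlip(),   # new in this commit: random left-right flips
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # map [0, 1] pixels to [-1, 1]
])
```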
```diff
diff --git a/dreambooth.py b/dreambooth.py
index c01cbe3..bc7a472 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import datetime
@@ -61,7 +60,8 @@ def parse_args():
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -72,7+72,8 @@ def parse_args():
         "--seed",
         type=int,
         default=None,
-        help="A seed for reproducible training.")
+        help="A seed for reproducible training."
+    )
     parser.add_argument(
         "--resolution",
         type=int,
@@ -94,7 +95,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=1000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -184,7 +185,7 @@ def parse_args():
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
-        default=500,
+        default=200,
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
@@ -220,7 +221,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=50,
+        default=80,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -381,7 +382,6 @@ class Checkpointer:
     def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = f"{self.output_dir}/samples/{mode}"
         os.makedirs(samples_path, exist_ok=True)
-        checker = NoCheck()
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
         pipeline = StableDiffusionPipeline(
@@ -507,6 +507,7 @@ def main():
         torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
         pipeline = StableDiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
+        pipeline.enable_attention_slicing()
         pipeline.set_progress_bar_config(disable=True)
 
         num_new_images = args.num_class_images - cur_class_images
```
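The pre-generation of class images now calls `enable_attention_slicing()`, a stock diffusers switch that computes attention in slices to lower peak VRAM at a small speed cost. A minimal sketch of the pattern in isolation (the model id below is a placeholder, not taken from this repository):

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder model id; the script itself uses args.pretrained_model_name_or_path.
pipeline = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipeline.enable_attention_slicing()            # lower peak memory during sampling
pipeline.set_progress_bar_config(disable=True)
pipeline.to("cuda")
```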
```diff
@@ -589,7 +590,11 @@ def main():
 
     # TODO (patil-suraj): laod scheduler using args
     noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
+        tensor_format="pt"
     )
 
     def collate_fn(examples):
```
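The scheduler construction above is only reformatted, but it is worth spelling out what it does during training: `add_noise()` blends the clean latents with Gaussian noise using the cumulative alphas of the `scaled_linear` beta schedule. A self-contained illustration of that computation in plain PyTorch (the training loop keeps calling `noise_scheduler.add_noise`; this is not diffusers code):

```python
import torch

# "scaled_linear": betas interpolated in sqrt-space, as configured above
num_train_timesteps = 1000
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, num_train_timesteps) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(latents, noise, timesteps):
    # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    sqrt_alpha = alphas_cumprod[timesteps] ** 0.5
    sqrt_one_minus_alpha = (1.0 - alphas_cumprod[timesteps]) ** 0.5
    while sqrt_alpha.dim() < latents.dim():
        sqrt_alpha = sqrt_alpha.unsqueeze(-1)
        sqrt_one_minus_alpha = sqrt_one_minus_alpha.unsqueeze(-1)
    return sqrt_alpha * latents + sqrt_one_minus_alpha * noise
```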
```diff
@@ -709,7 +714,7 @@ def main():
                               args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
-    local_progress_bar.set_description("Steps")
+    local_progress_bar.set_description("Steps ")
 
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Global steps")
@@ -723,31 +728,31 @@ def main():
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                with accelerator.autocast():
-                    # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                # Convert images to latent space
+                with torch.no_grad():
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
 
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn(latents.shape).to(latents.device)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
+                # Sample noise that we'll add to the latents
+                noise = torch.randn(latents.shape).to(latents.device)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                          (bsz,), device=latents.device)
+                timesteps = timesteps.long()
 
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                    # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                # Get the text embedding for conditioning
+                with torch.no_grad():
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
-                    # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                # Predict the noise residual
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
+                with accelerator.autocast():
                     if args.with_prior_preservation:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
@@ -766,12 +771,12 @@ def main():
                         loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
-                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
-
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
-                optimizer.zero_grad()
+                optimizer.zero_grad(set_to_none=True)
 
                 loss = loss.detach().item()
                 train_loss += loss
```
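Taken together, the hunks above restructure the inner training step: the frozen VAE and text encoder run under `torch.no_grad()`, autocast is narrowed to the loss computation, gradients are clipped only when `accelerator.sync_gradients` marks a real (non-accumulation) optimizer step, and `zero_grad(set_to_none=True)` releases gradient memory. A condensed sketch of the resulting step, with prior preservation left out for brevity:

```python
import torch
import torch.nn.functional as F

def training_step(accelerator, unet, vae, text_encoder, noise_scheduler,
                  optimizer, lr_scheduler, batch, max_grad_norm):
    """One inner step, mirroring the structure introduced in this commit."""
    with accelerator.accumulate(unet):
        with torch.no_grad():  # frozen VAE / text encoder
            latents = vae.encode(batch["pixel_values"]).latent_dist.sample() * 0.18215
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

        noise = torch.randn(latents.shape, device=latents.device)
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                  (latents.shape[0],), device=latents.device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        with accelerator.autocast():  # mixed precision only around the loss
            loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

        accelerator.backward(loss)
        if accelerator.sync_gradients:  # clip only on real optimizer steps
            accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
        optimizer.step()
        if not accelerator.optimizer_step_was_skipped:
            lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)

    return loss.detach().item()
```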
```diff
@@ -804,7 +809,7 @@ def main():
         val_loss = 0.0
 
         for step, batch in enumerate(val_dataloader):
-            with torch.no_grad(), accelerator.autocast():
+            with torch.no_grad():
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
 
@@ -822,18 +827,19 @@ def main():
 
                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                if args.with_prior_preservation:
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                with accelerator.autocast():
+                    if args.with_prior_preservation:
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
 
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                            reduction="none").mean([1, 2, 3]).mean()
+                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                                reduction="none").mean([1, 2, 3]).mean()
 
-                    loss = loss + args.prior_loss_weight * prior_loss
-                else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 loss = loss.detach().item()
                 val_loss += loss
```
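The validation loop gets the same treatment: only the loss math sits inside `accelerator.autocast()`. For clarity, the prior-preservation loss computed in both loops, assuming (as in the upstream DreamBooth example this script derives from) that each batch holds instance examples in the first half and class examples in the second half:

```python
import torch
import torch.nn.functional as F

def prior_preservation_loss(noise_pred, noise, prior_loss_weight):
    # Split predictions and targets into instance and class ("prior") halves
    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(noise, 2, dim=0)

    instance_loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()

    # The class-image term keeps the fine-tuned model from drifting away from the class prior
    return instance_loss + prior_loss_weight * prior_loss
```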
```diff
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..b9e9ff7
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,121 @@
+import argparse
+import datetime
+from pathlib import Path
+from torch import autocast
+from diffusers import StableDiffusionPipeline
+import torch
+import json
+from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from slugify import slugify
+from pipelines.stable_diffusion.no_check import NoCheck
+
+model_id = "path-to-your-trained-model"
+
+prompt = "A photo of sks dog in a bucket"
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--batch_num",
+        type=int,
+        default=50,
+    )
+    parser.add_argument(
+        "--steps",
+        type=int,
+        default=80,
+    )
+    parser.add_argument(
+        "--scale",
+        type=int,
+        default=7.5,
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="inference",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    seed = args.seed or torch.random.seed()
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
+    text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
+    vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
+    unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
+    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16)
+
+    pipeline = StableDiffusionPipeline(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+        ),
+        safety_checker=NoCheck(),
+        feature_extractor=feature_extractor
+    )
+    pipeline.enable_attention_slicing()
+    pipeline.to("cuda")
+
+    with autocast("cuda"):
+        for i in range(args.batch_num):
+            images = pipeline(
+                [args.prompt] * args.batch_size,
+                num_inference_steps=args.steps,
+                guidance_scale=args.scale,
+                generator=generator,
+            ).images
+
+            for j, image in enumerate(images):
+                image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg"))
+
+
+if __name__ == "__main__":
+    main()
```
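The new `infer.py` takes its options from the command line or from a JSON file passed via `--config`; `parse_args()` expects the flags under an `"args"` key and feeds them back through argparse as a `Namespace`. A hypothetical config written from Python (file name and values are illustrative, not from the repository):

```python
import json

# Hypothetical config, usable as: python infer.py --config conf-infer.json
config = {
    "args": {
        "model": "dreambooth-model",   # assumed path to a trained pipeline
        "prompt": "A photo of sks dog in a bucket",
        "batch_size": 4,
        "batch_num": 10,
        "steps": 80,
        "scale": 7.5,
        "output_dir": "inference",
    }
}

with open("conf-infer.json", "w") as f:
    json.dump(config, f, indent=2)
```

The same run could also be launched directly, e.g. `python infer.py --model dreambooth-model --prompt "A photo of sks dog in a bucket" --batch_size 4 --batch_num 10` (paths again assumed); outputs land in `inference/<timestamp>_<seed>_<slugified-prompt>/` as numbered JPEGs.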
