 .gitignore             |   1 +
 data/dreambooth/csv.py |   1 +
 dreambooth.py          |  98 ++++++++++---------
 infer.py               | 121 ++++++++++++++++++++
 4 files changed, 175 insertions(+), 46 deletions(-)
diff --git a/.gitignore b/.gitignore
--- a/.gitignore
+++ b/.gitignore
@@ -161,5 +161,6 @@ cython_debug/
 
 text-inversion-model/
 dreambooth-model/
+inference/
 conf*.json
 v1-inference.yaml*
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index e70c068..14c13bb 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -117,6 +117,7 @@ class CSVDataset(Dataset):
             [
                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
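For reference, the dataset's augmented transform stack now reads as below when pulled out on its own; a minimal sketch, with example values for the `size` and `center_crop` arguments the class already takes:

from torchvision import transforms

size, center_crop = 512, False  # example values; CSVDataset passes its own
image_transforms = transforms.Compose(
    [
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
        transforms.RandomHorizontalFlip(),   # new: mirrors each image with p=0.5
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # map [0, 1] pixels to [-1, 1]
    ]
)

The flip is a cheap augmentation that adds pose variety to the handful of subject images DreamBooth trains on.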
diff --git a/dreambooth.py b/dreambooth.py
index c01cbe3..bc7a472 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import os
 import datetime
@@ -61,7 +60,8 @@ def parse_args():
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -72,7 +72,8 @@ def parse_args():
         "--seed",
         type=int,
         default=None,
-        help="A seed for reproducible training.")
+        help="A seed for reproducible training."
+    )
     parser.add_argument(
         "--resolution",
         type=int,
@@ -94,7 +95,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=1000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -184,7 +185,7 @@ def parse_args():
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
-        default=500,
+        default=200,
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
@@ -220,7 +221,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=50,
+        default=80,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -381,7 +382,6 @@ class Checkpointer:
     def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = f"{self.output_dir}/samples/{mode}"
         os.makedirs(samples_path, exist_ok=True)
-        checker = NoCheck()
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
         pipeline = StableDiffusionPipeline(
@@ -507,6 +507,7 @@ def main():
         torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
         pipeline = StableDiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
+        pipeline.enable_attention_slicing()
         pipeline.set_progress_bar_config(disable=True)
 
         num_new_images = args.num_class_images - cur_class_images
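The `enable_attention_slicing()` call added here computes attention in slices rather than one large matmul, cutting peak VRAM during class-image generation at a small speed cost. A standalone sketch of the same setup; the checkpoint id is only an example, not what this script loads:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # example checkpoint, assumed
    torch_dtype=torch.float16,
)
pipe.enable_attention_slicing()  # attention computed in chunks -> lower peak VRAM
pipe = pipe.to("cuda")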
@@ -589,7 +590,11 @@ def main():
 
     # TODO (patil-suraj): laod scheduler using args
     noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
+        tensor_format="pt"
     )
 
     def collate_fn(examples):
@@ -709,7 +714,7 @@ def main():
                                 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
-    local_progress_bar.set_description("Steps")
+    local_progress_bar.set_description("Steps ")
 
     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Global steps")
@@ -723,31 +728,31 @@ def main():
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                with accelerator.autocast():
-                    # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
-
-                    # Sample noise that we'll add to the latents
-                    noise = torch.randn(latents.shape).to(latents.device)
-                    bsz = latents.shape[0]
-                    # Sample a random timestep for each image
-                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                              (bsz,), device=latents.device)
-                    timesteps = timesteps.long()
-
-                    # Add noise to the latents according to the noise magnitude at each timestep
-                    # (this is the forward diffusion process)
-                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                    # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-
-                    # Predict the noise residual
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
+                # Convert images to latent space
+                with torch.no_grad():
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn(latents.shape).to(latents.device)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                          (bsz,), device=latents.device)
+                timesteps = timesteps.long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Get the text embedding for conditioning
+                with torch.no_grad():
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                # Predict the noise residual
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                with accelerator.autocast():
                     if args.with_prior_preservation:
                         # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                         noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
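For readers unfamiliar with the scheduler API, `noise_scheduler.add_noise` implements the standard DDPM forward process. A conceptual sketch of what it computes, assuming `alphas_cumprod` is the scheduler's cumulative product of alphas:

import torch

def add_noise_sketch(latents, noise, timesteps, alphas_cumprod):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    a_bar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
    return a_bar.sqrt() * latents + (1.0 - a_bar).sqrt() * noise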
@@ -766,12 +771,12 @@ def main():
                         loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
-                accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
-
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
-                optimizer.zero_grad()
+                optimizer.zero_grad(set_to_none=True)
 
                 loss = loss.detach().item()
                 train_loss += loss
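Gating the clip on `accelerator.sync_gradients` matters under gradient accumulation: gradients are only complete on the step where accelerate flushes the accumulated micro-batches, so clipping earlier would measure a partial norm. A minimal sketch of the pattern, with a hypothetical `compute_loss` helper:

for batch in train_dataloader:
    with accelerator.accumulate(unet):
        loss = compute_loss(unet, batch)  # hypothetical helper
        accelerator.backward(loss)
        if accelerator.sync_gradients:  # True only on the flush step
            accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

`zero_grad(set_to_none=True)` frees the gradient tensors instead of zero-filling them, which saves a little memory and a kernel launch per step.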
@@ -804,7 +809,7 @@ def main():
         val_loss = 0.0
 
         for step, batch in enumerate(val_dataloader):
-            with torch.no_grad(), accelerator.autocast():
+            with torch.no_grad():
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
 
@@ -822,18 +827,19 @@ def main():
 
                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                if args.with_prior_preservation:
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
-
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
-
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                            reduction="none").mean([1, 2, 3]).mean()
-
-                    loss = loss + args.prior_loss_weight * prior_loss
-                else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                with accelerator.autocast():
+                    if args.with_prior_preservation:
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+
+                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                                reduction="none").mean([1, 2, 3]).mean()
+
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 loss = loss.detach().item()
                 val_loss += loss
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..b9e9ff7
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,121 @@
+import argparse
+import datetime
+from pathlib import Path
+from torch import autocast
+from diffusers import StableDiffusionPipeline
+import torch
+import json
+from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from slugify import slugify
+from pipelines.stable_diffusion.no_check import NoCheck
+
+model_id = "path-to-your-trained-model"
+
+prompt = "A photo of sks dog in a bucket"
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of an inference script."
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+    )
+    parser.add_argument(
+        "--batch_num",
+        type=int,
+        default=50,
+    )
+    parser.add_argument(
+        "--steps",
+        type=int,
+        default=80,
+    )
+    parser.add_argument(
+        "--scale",
+        type=float,
+        default=7.5,
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="inference",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    seed = args.seed or torch.random.seed()
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
+    text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
+    vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
+    unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
+    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16)
+
+    pipeline = StableDiffusionPipeline(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+        ),
+        safety_checker=NoCheck(),
+        feature_extractor=feature_extractor
+    )
+    pipeline.enable_attention_slicing()
+    pipeline.to("cuda")
+
+    with autocast("cuda"):
+        for i in range(args.batch_num):
+            images = pipeline(
+                [args.prompt] * args.batch_size,
+                num_inference_steps=args.steps,
+                guidance_scale=args.scale,
+                generator=generator,
+            ).images
+
+            for j, image in enumerate(images):
+                image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg"))
+
+
+if __name__ == "__main__":
+    main()
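Assuming a model saved by the training script above, a typical invocation would be (paths and counts are illustrative; the flags come straight from the parser):

python infer.py \
    --model dreambooth-model \
    --prompt "A photo of sks dog in a bucket" \
    --batch_size 2 \
    --batch_num 10 \
    --steps 80 \
    --scale 7.5

Each run writes its images to inference/<timestamp>_<seed>_<slugified-prompt>/ as sequentially numbered JPEGs, so repeated runs never collide.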