summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--data/dreambooth/csv.py1
-rw-r--r--dreambooth.py98
-rw-r--r--infer.py121
4 files changed, 175 insertions, 46 deletions
diff --git a/.gitignore b/.gitignore
index 91a5e07..218c628 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,5 +161,6 @@ cython_debug/
161 161
162text-inversion-model/ 162text-inversion-model/
163dreambooth-model/ 163dreambooth-model/
164inference/
164conf*.json 165conf*.json
165v1-inference.yaml* 166v1-inference.yaml*
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index e70c068..14c13bb 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -117,6 +117,7 @@ class CSVDataset(Dataset):
117 [ 117 [
118 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 118 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
119 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 119 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
120 transforms.RandomHorizontalFlip(),
120 transforms.ToTensor(), 121 transforms.ToTensor(),
121 transforms.Normalize([0.5], [0.5]), 122 transforms.Normalize([0.5], [0.5]),
122 ] 123 ]
diff --git a/dreambooth.py b/dreambooth.py
index c01cbe3..bc7a472 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,5 +1,4 @@
1import argparse 1import argparse
2import itertools
3import math 2import math
4import os 3import os
5import datetime 4import datetime
@@ -61,7 +60,8 @@ def parse_args():
61 "--repeats", 60 "--repeats",
62 type=int, 61 type=int,
63 default=100, 62 default=100,
64 help="How many times to repeat the training data.") 63 help="How many times to repeat the training data."
64 )
65 parser.add_argument( 65 parser.add_argument(
66 "--output_dir", 66 "--output_dir",
67 type=str, 67 type=str,
@@ -72,7 +72,8 @@ def parse_args():
72 "--seed", 72 "--seed",
73 type=int, 73 type=int,
74 default=None, 74 default=None,
75 help="A seed for reproducible training.") 75 help="A seed for reproducible training."
76 )
76 parser.add_argument( 77 parser.add_argument(
77 "--resolution", 78 "--resolution",
78 type=int, 79 type=int,
@@ -94,7 +95,7 @@ def parse_args():
94 parser.add_argument( 95 parser.add_argument(
95 "--max_train_steps", 96 "--max_train_steps",
96 type=int, 97 type=int,
97 default=5000, 98 default=1000,
98 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 99 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
99 ) 100 )
100 parser.add_argument( 101 parser.add_argument(
@@ -184,7 +185,7 @@ def parse_args():
184 parser.add_argument( 185 parser.add_argument(
185 "--checkpoint_frequency", 186 "--checkpoint_frequency",
186 type=int, 187 type=int,
187 default=500, 188 default=200,
188 help="How often to save a checkpoint and sample image", 189 help="How often to save a checkpoint and sample image",
189 ) 190 )
190 parser.add_argument( 191 parser.add_argument(
@@ -220,7 +221,7 @@ def parse_args():
220 parser.add_argument( 221 parser.add_argument(
221 "--sample_steps", 222 "--sample_steps",
222 type=int, 223 type=int,
223 default=50, 224 default=80,
224 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 225 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
225 ) 226 )
226 parser.add_argument( 227 parser.add_argument(
@@ -381,7 +382,6 @@ class Checkpointer:
381 def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps): 382 def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
382 samples_path = f"{self.output_dir}/samples/{mode}" 383 samples_path = f"{self.output_dir}/samples/{mode}"
383 os.makedirs(samples_path, exist_ok=True) 384 os.makedirs(samples_path, exist_ok=True)
384 checker = NoCheck()
385 385
386 unwrapped = self.accelerator.unwrap_model(self.unet) 386 unwrapped = self.accelerator.unwrap_model(self.unet)
387 pipeline = StableDiffusionPipeline( 387 pipeline = StableDiffusionPipeline(
@@ -507,6 +507,7 @@ def main():
507 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 507 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
508 pipeline = StableDiffusionPipeline.from_pretrained( 508 pipeline = StableDiffusionPipeline.from_pretrained(
509 args.pretrained_model_name_or_path, torch_dtype=torch_dtype) 509 args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
510 pipeline.enable_attention_slicing()
510 pipeline.set_progress_bar_config(disable=True) 511 pipeline.set_progress_bar_config(disable=True)
511 512
512 num_new_images = args.num_class_images - cur_class_images 513 num_new_images = args.num_class_images - cur_class_images
@@ -589,7 +590,11 @@ def main():
589 590
590 # TODO (patil-suraj): laod scheduler using args 591 # TODO (patil-suraj): laod scheduler using args
591 noise_scheduler = DDPMScheduler( 592 noise_scheduler = DDPMScheduler(
592 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" 593 beta_start=0.00085,
594 beta_end=0.012,
595 beta_schedule="scaled_linear",
596 num_train_timesteps=1000,
597 tensor_format="pt"
593 ) 598 )
594 599
595 def collate_fn(examples): 600 def collate_fn(examples):
@@ -709,7 +714,7 @@ def main():
709 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 714 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
710 715
711 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) 716 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
712 local_progress_bar.set_description("Steps") 717 local_progress_bar.set_description("Steps ")
713 718
714 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 719 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
715 progress_bar.set_description("Global steps") 720 progress_bar.set_description("Global steps")
@@ -723,31 +728,31 @@ def main():
723 728
724 for step, batch in enumerate(train_dataloader): 729 for step, batch in enumerate(train_dataloader):
725 with accelerator.accumulate(unet): 730 with accelerator.accumulate(unet):
726 with accelerator.autocast(): 731 # Convert images to latent space
727 # Convert images to latent space 732 with torch.no_grad():
728 with torch.no_grad(): 733 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
729 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 734 latents = latents * 0.18215
730 latents = latents * 0.18215 735
731 736 # Sample noise that we'll add to the latents
732 # Sample noise that we'll add to the latents 737 noise = torch.randn(latents.shape).to(latents.device)
733 noise = torch.randn(latents.shape).to(latents.device) 738 bsz = latents.shape[0]
734 bsz = latents.shape[0] 739 # Sample a random timestep for each image
735 # Sample a random timestep for each image 740 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
736 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, 741 (bsz,), device=latents.device)
737 (bsz,), device=latents.device) 742 timesteps = timesteps.long()
738 timesteps = timesteps.long() 743
739 744 # Add noise to the latents according to the noise magnitude at each timestep
740 # Add noise to the latents according to the noise magnitude at each timestep 745 # (this is the forward diffusion process)
741 # (this is the forward diffusion process) 746 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
742 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 747
743 748 # Get the text embedding for conditioning
744 # Get the text embedding for conditioning 749 with torch.no_grad():
745 with torch.no_grad(): 750 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
746 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 751
747 752 # Predict the noise residual
748 # Predict the noise residual 753 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
749 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
750 754
755 with accelerator.autocast():
751 if args.with_prior_preservation: 756 if args.with_prior_preservation:
752 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 757 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
753 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 758 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
@@ -766,12 +771,12 @@ def main():
766 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 771 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
767 772
768 accelerator.backward(loss) 773 accelerator.backward(loss)
769 accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) 774 if accelerator.sync_gradients:
770 775 accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
771 optimizer.step() 776 optimizer.step()
772 if not accelerator.optimizer_step_was_skipped: 777 if not accelerator.optimizer_step_was_skipped:
773 lr_scheduler.step() 778 lr_scheduler.step()
774 optimizer.zero_grad() 779 optimizer.zero_grad(set_to_none=True)
775 780
776 loss = loss.detach().item() 781 loss = loss.detach().item()
777 train_loss += loss 782 train_loss += loss
@@ -804,7 +809,7 @@ def main():
804 val_loss = 0.0 809 val_loss = 0.0
805 810
806 for step, batch in enumerate(val_dataloader): 811 for step, batch in enumerate(val_dataloader):
807 with torch.no_grad(), accelerator.autocast(): 812 with torch.no_grad():
808 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 813 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
809 latents = latents * 0.18215 814 latents = latents * 0.18215
810 815
@@ -822,18 +827,19 @@ def main():
822 827
823 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 828 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
824 829
825 if args.with_prior_preservation: 830 with accelerator.autocast():
826 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 831 if args.with_prior_preservation:
827 noise, noise_prior = torch.chunk(noise, 2, dim=0) 832 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
833 noise, noise_prior = torch.chunk(noise, 2, dim=0)
828 834
829 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 835 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
830 836
831 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, 837 prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
832 reduction="none").mean([1, 2, 3]).mean() 838 reduction="none").mean([1, 2, 3]).mean()
833 839
834 loss = loss + args.prior_loss_weight * prior_loss 840 loss = loss + args.prior_loss_weight * prior_loss
835 else: 841 else:
836 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 842 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
837 843
838 loss = loss.detach().item() 844 loss = loss.detach().item()
839 val_loss += loss 845 val_loss += loss
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..b9e9ff7
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,121 @@
1import argparse
2import datetime
3from pathlib import Path
4from torch import autocast
5from diffusers import StableDiffusionPipeline
6import torch
7import json
8from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
9from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
10from slugify import slugify
11from pipelines.stable_diffusion.no_check import NoCheck
12
13model_id = "path-to-your-trained-model"
14
15prompt = "A photo of sks dog in a bucket"
16
17
18def parse_args():
19 parser = argparse.ArgumentParser(
20 description="Simple example of a training script."
21 )
22 parser.add_argument(
23 "--model",
24 type=str,
25 default=None,
26 )
27 parser.add_argument(
28 "--prompt",
29 type=str,
30 default=None,
31 )
32 parser.add_argument(
33 "--batch_size",
34 type=int,
35 default=1,
36 )
37 parser.add_argument(
38 "--batch_num",
39 type=int,
40 default=50,
41 )
42 parser.add_argument(
43 "--steps",
44 type=int,
45 default=80,
46 )
47 parser.add_argument(
48 "--scale",
49 type=int,
50 default=7.5,
51 )
52 parser.add_argument(
53 "--seed",
54 type=int,
55 default=None,
56 )
57 parser.add_argument(
58 "--output_dir",
59 type=str,
60 default="inference",
61 )
62 parser.add_argument(
63 "--config",
64 type=str,
65 default=None,
66 )
67
68 args = parser.parse_args()
69 if args.config is not None:
70 with open(args.config, 'rt') as f:
71 args = parser.parse_args(
72 namespace=argparse.Namespace(**json.load(f)["args"]))
73
74 return args
75
76
77def main():
78 args = parse_args()
79
80 seed = args.seed or torch.random.seed()
81 generator = torch.Generator(device="cuda").manual_seed(seed)
82
83 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
84 output_dir = Path(args.output_dir).joinpath(f"{now}_{seed}_{slugify(args.prompt)[:80]}")
85 output_dir.mkdir(parents=True, exist_ok=True)
86
87 tokenizer = CLIPTokenizer.from_pretrained(args.model + '/tokenizer', torch_dtype=torch.bfloat16)
88 text_encoder = CLIPTextModel.from_pretrained(args.model + '/text_encoder', torch_dtype=torch.bfloat16)
89 vae = AutoencoderKL.from_pretrained(args.model + '/vae', torch_dtype=torch.bfloat16)
90 unet = UNet2DConditionModel.from_pretrained(args.model + '/unet', torch_dtype=torch.bfloat16)
91 feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16)
92
93 pipeline = StableDiffusionPipeline(
94 text_encoder=text_encoder,
95 vae=vae,
96 unet=unet,
97 tokenizer=tokenizer,
98 scheduler=PNDMScheduler(
99 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
100 ),
101 safety_checker=NoCheck(),
102 feature_extractor=feature_extractor
103 )
104 pipeline.enable_attention_slicing()
105 pipeline.to("cuda")
106
107 with autocast("cuda"):
108 for i in range(args.batch_num):
109 images = pipeline(
110 [args.prompt] * args.batch_size,
111 num_inference_steps=args.steps,
112 guidance_scale=args.scale,
113 generator=generator,
114 ).images
115
116 for j, image in enumerate(images):
117 image.save(output_dir.joinpath(f"{i * args.batch_size + j}.jpg"))
118
119
120if __name__ == "__main__":
121 main()