path: root/dreambooth_plus.py
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--  dreambooth_plus.py  28

1 file changed, 10 insertions(+), 18 deletions(-)
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index ae31377..fa3a22b 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -26,6 +26,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_a import EulerAScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
 
@@ -147,7 +148,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=1e-6,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -470,7 +471,7 @@ class Checkpointer:
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
             prompt = [
-                [p.format(self.instance_identifier) for p in prompt]
+                prompt.format(self.instance_identifier)
                 for batch in batches
                 for prompt in batch["prompts"]
             ][:self.sample_batch_size]
@@ -573,6 +574,8 @@ def main():
         device=accelerator.device
     ) if args.use_ema else None
 
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
@@ -663,7 +666,7 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+        input_ids = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
@@ -673,21 +676,10 @@ def main():
         }
         return batch
 
-    def encode_input_ids(input_ids):
-        text_embeddings = []
-
-        for ids in input_ids:
-            embeddings = text_encoder(ids)[0]
-            embeddings = embeddings.reshape((1, -1, 768))
-            text_embeddings.append(embeddings)
-
-        text_embeddings = torch.cat(text_embeddings)
-        return text_embeddings
-
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
+        prompt_processor=prompt_processor,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
@@ -727,7 +719,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
-                prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch]
+                prompt = [item.prompt.format(args.class_identifier) for item in batch]
                 nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
@@ -875,7 +867,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = encode_input_ids(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -974,7 +966,7 @@ def main():
 
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                encoder_hidden_states = encode_input_ids(batch["input_ids"])
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
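
The new models/clip/prompt.py is not part of this diff. For reference, a minimal sketch of what PromptProcessor might look like, inferred only from its call sites above: the constructor takes tokenizer and text_encoder, unify_input_ids replaces the inline tokenizer.pad(...) call in the collate function, and get_embeddings replaces the removed encode_input_ids helper. The actual module may differ.

    # Hypothetical reconstruction of models/clip/prompt.py, inferred from this
    # diff's call sites; not the actual file contents.
    import torch


    class PromptProcessor:
        def __init__(self, tokenizer, text_encoder):
            self.tokenizer = tokenizer
            self.text_encoder = text_encoder

        def unify_input_ids(self, input_ids):
            # Pad the per-example token-id lists to a common length and return
            # a single tensor, as the old inline tokenizer.pad(...) call did.
            return self.tokenizer.pad(
                {"input_ids": input_ids},
                padding=True,
                return_tensors="pt",
            ).input_ids

        def get_embeddings(self, input_ids):
            # Mirrors the removed encode_input_ids helper: encode each prompt,
            # flatten its hidden states into one (1, n, 768) row, then batch
            # the rows together.
            text_embeddings = []
            for ids in input_ids:
                embeddings = self.text_encoder(ids)[0]
                embeddings = embeddings.reshape((1, -1, 768))
                text_embeddings.append(embeddings)
            return torch.cat(text_embeddings)

Centralizing tokenization and encoding in one object presumably lets CSVDataModule and the training loop share the same prompt handling instead of duplicating it at each call site.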