diff options
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r-- | dreambooth_plus.py | 28 |
1 files changed, 10 insertions, 18 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index ae31377..fa3a22b 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -26,6 +26,7 @@ from slugify import slugify | |||
26 | from schedulers.scheduling_euler_a import EulerAScheduler | 26 | from schedulers.scheduling_euler_a import EulerAScheduler |
27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
28 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
29 | from models.clip.prompt import PromptProcessor | ||
29 | 30 | ||
30 | logger = get_logger(__name__) | 31 | logger = get_logger(__name__) |
31 | 32 | ||
@@ -147,7 +148,7 @@ def parse_args(): | |||
147 | parser.add_argument( | 148 | parser.add_argument( |
148 | "--learning_rate_text", | 149 | "--learning_rate_text", |
149 | type=float, | 150 | type=float, |
150 | default=1e-6, | 151 | default=5e-6, |
151 | help="Initial learning rate (after the potential warmup period) to use.", | 152 | help="Initial learning rate (after the potential warmup period) to use.", |
152 | ) | 153 | ) |
153 | parser.add_argument( | 154 | parser.add_argument( |
@@ -470,7 +471,7 @@ class Checkpointer: | |||
470 | for i in range(self.sample_batches): | 471 | for i in range(self.sample_batches): |
471 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 472 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
472 | prompt = [ | 473 | prompt = [ |
473 | [p.format(self.instance_identifier) for p in prompt] | 474 | prompt.format(self.instance_identifier) |
474 | for batch in batches | 475 | for batch in batches |
475 | for prompt in batch["prompts"] | 476 | for prompt in batch["prompts"] |
476 | ][:self.sample_batch_size] | 477 | ][:self.sample_batch_size] |
@@ -573,6 +574,8 @@ def main(): | |||
573 | device=accelerator.device | 574 | device=accelerator.device |
574 | ) if args.use_ema else None | 575 | ) if args.use_ema else None |
575 | 576 | ||
577 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
578 | |||
576 | if args.gradient_checkpointing: | 579 | if args.gradient_checkpointing: |
577 | unet.enable_gradient_checkpointing() | 580 | unet.enable_gradient_checkpointing() |
578 | 581 | ||
@@ -663,7 +666,7 @@ def main(): | |||
663 | pixel_values = torch.stack(pixel_values) | 666 | pixel_values = torch.stack(pixel_values) |
664 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 667 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
665 | 668 | ||
666 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 669 | input_ids = prompt_processor.unify_input_ids(input_ids) |
667 | 670 | ||
668 | batch = { | 671 | batch = { |
669 | "prompts": prompts, | 672 | "prompts": prompts, |
@@ -673,21 +676,10 @@ def main(): | |||
673 | } | 676 | } |
674 | return batch | 677 | return batch |
675 | 678 | ||
676 | def encode_input_ids(input_ids): | ||
677 | text_embeddings = [] | ||
678 | |||
679 | for ids in input_ids: | ||
680 | embeddings = text_encoder(ids)[0] | ||
681 | embeddings = embeddings.reshape((1, -1, 768)) | ||
682 | text_embeddings.append(embeddings) | ||
683 | |||
684 | text_embeddings = torch.cat(text_embeddings) | ||
685 | return text_embeddings | ||
686 | |||
687 | datamodule = CSVDataModule( | 679 | datamodule = CSVDataModule( |
688 | data_file=args.train_data_file, | 680 | data_file=args.train_data_file, |
689 | batch_size=args.train_batch_size, | 681 | batch_size=args.train_batch_size, |
690 | tokenizer=tokenizer, | 682 | prompt_processor=prompt_processor, |
691 | instance_identifier=args.instance_identifier, | 683 | instance_identifier=args.instance_identifier, |
692 | class_identifier=args.class_identifier, | 684 | class_identifier=args.class_identifier, |
693 | class_subdir="cls", | 685 | class_subdir="cls", |
@@ -727,7 +719,7 @@ def main(): | |||
727 | with torch.inference_mode(): | 719 | with torch.inference_mode(): |
728 | for batch in batched_data: | 720 | for batch in batched_data: |
729 | image_name = [item.class_image_path for item in batch] | 721 | image_name = [item.class_image_path for item in batch] |
730 | prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch] | 722 | prompt = [item.prompt.format(args.class_identifier) for item in batch] |
731 | nprompt = [item.nprompt for item in batch] | 723 | nprompt = [item.nprompt for item in batch] |
732 | 724 | ||
733 | images = pipeline( | 725 | images = pipeline( |
@@ -875,7 +867,7 @@ def main(): | |||
875 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 867 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
876 | 868 | ||
877 | # Get the text embedding for conditioning | 869 | # Get the text embedding for conditioning |
878 | encoder_hidden_states = encode_input_ids(batch["input_ids"]) | 870 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
879 | 871 | ||
880 | # Predict the noise residual | 872 | # Predict the noise residual |
881 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 873 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
@@ -974,7 +966,7 @@ def main(): | |||
974 | 966 | ||
975 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 967 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
976 | 968 | ||
977 | encoder_hidden_states = encode_input_ids(batch["input_ids"]) | 969 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
978 | 970 | ||
979 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 971 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
980 | 972 | ||