diff options
Diffstat (limited to 'dreambooth_plus.py')
| -rw-r--r-- | dreambooth_plus.py | 28 |
1 files changed, 10 insertions, 18 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index ae31377..fa3a22b 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
| @@ -26,6 +26,7 @@ from slugify import slugify | |||
| 26 | from schedulers.scheduling_euler_a import EulerAScheduler | 26 | from schedulers.scheduling_euler_a import EulerAScheduler |
| 27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 28 | from data.csv import CSVDataModule | 28 | from data.csv import CSVDataModule |
| 29 | from models.clip.prompt import PromptProcessor | ||
| 29 | 30 | ||
| 30 | logger = get_logger(__name__) | 31 | logger = get_logger(__name__) |
| 31 | 32 | ||
| @@ -147,7 +148,7 @@ def parse_args(): | |||
| 147 | parser.add_argument( | 148 | parser.add_argument( |
| 148 | "--learning_rate_text", | 149 | "--learning_rate_text", |
| 149 | type=float, | 150 | type=float, |
| 150 | default=1e-6, | 151 | default=5e-6, |
| 151 | help="Initial learning rate (after the potential warmup period) to use.", | 152 | help="Initial learning rate (after the potential warmup period) to use.", |
| 152 | ) | 153 | ) |
| 153 | parser.add_argument( | 154 | parser.add_argument( |
| @@ -470,7 +471,7 @@ class Checkpointer: | |||
| 470 | for i in range(self.sample_batches): | 471 | for i in range(self.sample_batches): |
| 471 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 472 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
| 472 | prompt = [ | 473 | prompt = [ |
| 473 | [p.format(self.instance_identifier) for p in prompt] | 474 | prompt.format(self.instance_identifier) |
| 474 | for batch in batches | 475 | for batch in batches |
| 475 | for prompt in batch["prompts"] | 476 | for prompt in batch["prompts"] |
| 476 | ][:self.sample_batch_size] | 477 | ][:self.sample_batch_size] |
| @@ -573,6 +574,8 @@ def main(): | |||
| 573 | device=accelerator.device | 574 | device=accelerator.device |
| 574 | ) if args.use_ema else None | 575 | ) if args.use_ema else None |
| 575 | 576 | ||
| 577 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
| 578 | |||
| 576 | if args.gradient_checkpointing: | 579 | if args.gradient_checkpointing: |
| 577 | unet.enable_gradient_checkpointing() | 580 | unet.enable_gradient_checkpointing() |
| 578 | 581 | ||
| @@ -663,7 +666,7 @@ def main(): | |||
| 663 | pixel_values = torch.stack(pixel_values) | 666 | pixel_values = torch.stack(pixel_values) |
| 664 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 667 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
| 665 | 668 | ||
| 666 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 669 | input_ids = prompt_processor.unify_input_ids(input_ids) |
| 667 | 670 | ||
| 668 | batch = { | 671 | batch = { |
| 669 | "prompts": prompts, | 672 | "prompts": prompts, |
| @@ -673,21 +676,10 @@ def main(): | |||
| 673 | } | 676 | } |
| 674 | return batch | 677 | return batch |
| 675 | 678 | ||
| 676 | def encode_input_ids(input_ids): | ||
| 677 | text_embeddings = [] | ||
| 678 | |||
| 679 | for ids in input_ids: | ||
| 680 | embeddings = text_encoder(ids)[0] | ||
| 681 | embeddings = embeddings.reshape((1, -1, 768)) | ||
| 682 | text_embeddings.append(embeddings) | ||
| 683 | |||
| 684 | text_embeddings = torch.cat(text_embeddings) | ||
| 685 | return text_embeddings | ||
| 686 | |||
| 687 | datamodule = CSVDataModule( | 679 | datamodule = CSVDataModule( |
| 688 | data_file=args.train_data_file, | 680 | data_file=args.train_data_file, |
| 689 | batch_size=args.train_batch_size, | 681 | batch_size=args.train_batch_size, |
| 690 | tokenizer=tokenizer, | 682 | prompt_processor=prompt_processor, |
| 691 | instance_identifier=args.instance_identifier, | 683 | instance_identifier=args.instance_identifier, |
| 692 | class_identifier=args.class_identifier, | 684 | class_identifier=args.class_identifier, |
| 693 | class_subdir="cls", | 685 | class_subdir="cls", |
| @@ -727,7 +719,7 @@ def main(): | |||
| 727 | with torch.inference_mode(): | 719 | with torch.inference_mode(): |
| 728 | for batch in batched_data: | 720 | for batch in batched_data: |
| 729 | image_name = [item.class_image_path for item in batch] | 721 | image_name = [item.class_image_path for item in batch] |
| 730 | prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch] | 722 | prompt = [item.prompt.format(args.class_identifier) for item in batch] |
| 731 | nprompt = [item.nprompt for item in batch] | 723 | nprompt = [item.nprompt for item in batch] |
| 732 | 724 | ||
| 733 | images = pipeline( | 725 | images = pipeline( |
| @@ -875,7 +867,7 @@ def main(): | |||
| 875 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 867 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 876 | 868 | ||
| 877 | # Get the text embedding for conditioning | 869 | # Get the text embedding for conditioning |
| 878 | encoder_hidden_states = encode_input_ids(batch["input_ids"]) | 870 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
| 879 | 871 | ||
| 880 | # Predict the noise residual | 872 | # Predict the noise residual |
| 881 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 873 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| @@ -974,7 +966,7 @@ def main(): | |||
| 974 | 966 | ||
| 975 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 967 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 976 | 968 | ||
| 977 | encoder_hidden_states = encode_input_ids(batch["input_ids"]) | 969 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
| 978 | 970 | ||
| 979 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 971 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 980 | 972 | ||
