Diffstat (limited to 'textual_inversion.py')

 textual_inversion.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 69d9c7f..8f266e0 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,6 +25,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_a import EulerAScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
 
@@ -152,7 +153,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -516,6 +517,8 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='unet')
 
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
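The new models/clip/prompt.py module is not part of this diff, but its interface can be read off the call sites: the processor is built from the tokenizer and text encoder, unify_input_ids() pads a batch of token lists, and get_embeddings() produces the conditioning tensor. A minimal sketch consistent with those call sites follows; the names and behavior beyond what the hunks below show are assumptions (the real class may, for instance, chunk prompts that exceed the tokenizer's model_max_length):

    # models/clip/prompt.py -- hypothetical sketch inferred from the call
    # sites in this diff; not the actual implementation.
    class PromptProcessor:
        def __init__(self, tokenizer, text_encoder):
            self.tokenizer = tokenizer
            self.text_encoder = text_encoder

        def unify_input_ids(self, input_ids):
            # Pad variable-length token sequences into one batch tensor,
            # mirroring the tokenizer.pad(...) call removed below.
            return self.tokenizer.pad(
                {"input_ids": input_ids}, padding=True, return_tensors="pt"
            ).input_ids

        def get_embeddings(self, input_ids):
            # Encode and return the last hidden state, mirroring the
            # text_encoder(batch["input_ids"])[0] calls removed below.
            return self.text_encoder(input_ids)[0]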
@@ -594,7 +597,7 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
 
-        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+        input_ids = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
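Assuming the sketch above, this hunk is a behavior-preserving drop-in: the collate function no longer needs direct access to the tokenizer, and prompt handling is consolidated behind the PromptProcessor.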
@@ -607,7 +610,7 @@ def main():
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
+        prompt_processor=prompt_processor,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
@@ -678,8 +681,8 @@ def main():
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
+            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
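The num_cycles default changes from a linear heuristic (one warm restart per two epochs of the decay phase) to a square-root one, so the restart count grows much more slowly on long runs. With hypothetical numbers chosen only to illustrate the formula change:

    # Assumed values: max_train_steps=6000, lr_warmup_steps=200,
    # num_update_steps_per_epoch=100.
    import math

    decay_epochs = (6000 - 200) / 100                # 58.0
    old_cycles = math.ceil(decay_epochs / 2)         # 29 restarts
    new_cycles = math.ceil(math.sqrt(decay_epochs))  # 8 restarts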
@@ -794,7 +797,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -885,7 +888,7 @@ def main():
 
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
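With these last two hunks, the training and validation loops obtain conditioning through the same prompt_processor.get_embeddings() path, so any future change to prompt encoding (for example, chunked handling of over-length prompts, if the real class supports it) only has to be made in one place.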