diff options
author | Volpeon <git@volpeon.ink> | 2022-10-18 18:08:32 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-18 18:08:32 +0200 |
commit | 2ddfbd65e482fa2361e8ba41b657656f825c9143 (patch) | |
tree | 41cc82e23d82dd620c81f2715a50969b832e9bda | |
parent | Improved prompt handling (diff) | |
download | textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.tar.gz textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.tar.bz2 textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.zip |
Adapted other scripts for new prompt processing
-rw-r--r-- | dreambooth.py | 17 | ||||
-rw-r--r-- | dreambooth_plus.py | 6 | ||||
-rw-r--r-- | textual_inversion.py | 17 |
3 files changed, 23 insertions, 17 deletions
diff --git a/dreambooth.py b/dreambooth.py index 770ad38..9786e0f 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -25,6 +25,7 @@ from slugify import slugify | |||
25 | from schedulers.scheduling_euler_a import EulerAScheduler | 25 | from schedulers.scheduling_euler_a import EulerAScheduler |
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
28 | from models.clip.prompt import PromptProcessor | ||
28 | 29 | ||
29 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
30 | 31 | ||
@@ -141,7 +142,7 @@ def parse_args(): | |||
141 | parser.add_argument( | 142 | parser.add_argument( |
142 | "--lr_scheduler", | 143 | "--lr_scheduler", |
143 | type=str, | 144 | type=str, |
144 | default="cosine", | 145 | default="cosine_with_restarts", |
145 | help=( | 146 | help=( |
146 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 147 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
147 | ' "constant", "constant_with_warmup"]' | 148 | ' "constant", "constant_with_warmup"]' |
@@ -494,6 +495,8 @@ def main(): | |||
494 | device=accelerator.device | 495 | device=accelerator.device |
495 | ) if args.use_ema else None | 496 | ) if args.use_ema else None |
496 | 497 | ||
498 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
499 | |||
497 | if args.gradient_checkpointing: | 500 | if args.gradient_checkpointing: |
498 | unet.enable_gradient_checkpointing() | 501 | unet.enable_gradient_checkpointing() |
499 | 502 | ||
@@ -557,7 +560,7 @@ def main(): | |||
557 | pixel_values = torch.stack(pixel_values) | 560 | pixel_values = torch.stack(pixel_values) |
558 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 561 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
559 | 562 | ||
560 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 563 | input_ids = prompt_processor.unify_input_ids(input_ids) |
561 | 564 | ||
562 | batch = { | 565 | batch = { |
563 | "prompts": prompts, | 566 | "prompts": prompts, |
@@ -570,7 +573,7 @@ def main(): | |||
570 | datamodule = CSVDataModule( | 573 | datamodule = CSVDataModule( |
571 | data_file=args.train_data_file, | 574 | data_file=args.train_data_file, |
572 | batch_size=args.train_batch_size, | 575 | batch_size=args.train_batch_size, |
573 | tokenizer=tokenizer, | 576 | prompt_processor=prompt_processor, |
574 | instance_identifier=args.instance_identifier, | 577 | instance_identifier=args.instance_identifier, |
575 | class_identifier=args.class_identifier, | 578 | class_identifier=args.class_identifier, |
576 | class_subdir="cls", | 579 | class_subdir="cls", |
@@ -641,8 +644,8 @@ def main(): | |||
641 | optimizer=optimizer, | 644 | optimizer=optimizer, |
642 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 645 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
643 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 646 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
644 | num_cycles=args.lr_cycles or math.ceil( | 647 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
645 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 648 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), |
646 | ) | 649 | ) |
647 | else: | 650 | else: |
648 | lr_scheduler = get_scheduler( | 651 | lr_scheduler = get_scheduler( |
@@ -756,7 +759,7 @@ def main(): | |||
756 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 759 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
757 | 760 | ||
758 | # Get the text embedding for conditioning | 761 | # Get the text embedding for conditioning |
759 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 762 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
760 | 763 | ||
761 | # Predict the noise residual | 764 | # Predict the noise residual |
762 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 765 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
@@ -832,7 +835,7 @@ def main(): | |||
832 | 835 | ||
833 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 836 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
834 | 837 | ||
835 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 838 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
836 | 839 | ||
837 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 840 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
838 | 841 | ||
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index fa3a22b..06ff45b 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -125,7 +125,7 @@ def parse_args(): | |||
125 | parser.add_argument( | 125 | parser.add_argument( |
126 | "--max_train_steps", | 126 | "--max_train_steps", |
127 | type=int, | 127 | type=int, |
128 | default=1400, | 128 | default=2400, |
129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
130 | ) | 130 | ) |
131 | parser.add_argument( | 131 | parser.add_argument( |
@@ -752,8 +752,8 @@ def main(): | |||
752 | optimizer=optimizer, | 752 | optimizer=optimizer, |
753 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 753 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
754 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 754 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
755 | num_cycles=args.lr_cycles or math.ceil( | 755 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
756 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 756 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), |
757 | ) | 757 | ) |
758 | else: | 758 | else: |
759 | lr_scheduler = get_scheduler( | 759 | lr_scheduler = get_scheduler( |
diff --git a/textual_inversion.py b/textual_inversion.py index 69d9c7f..8f266e0 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -25,6 +25,7 @@ from slugify import slugify | |||
25 | from schedulers.scheduling_euler_a import EulerAScheduler | 25 | from schedulers.scheduling_euler_a import EulerAScheduler |
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
28 | from models.clip.prompt import PromptProcessor | ||
28 | 29 | ||
29 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
30 | 31 | ||
@@ -152,7 +153,7 @@ def parse_args(): | |||
152 | parser.add_argument( | 153 | parser.add_argument( |
153 | "--lr_scheduler", | 154 | "--lr_scheduler", |
154 | type=str, | 155 | type=str, |
155 | default="cosine", | 156 | default="cosine_with_restarts", |
156 | help=( | 157 | help=( |
157 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 158 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
158 | ' "constant", "constant_with_warmup"]' | 159 | ' "constant", "constant_with_warmup"]' |
@@ -516,6 +517,8 @@ def main(): | |||
516 | unet = UNet2DConditionModel.from_pretrained( | 517 | unet = UNet2DConditionModel.from_pretrained( |
517 | args.pretrained_model_name_or_path, subfolder='unet') | 518 | args.pretrained_model_name_or_path, subfolder='unet') |
518 | 519 | ||
520 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
521 | |||
519 | if args.gradient_checkpointing: | 522 | if args.gradient_checkpointing: |
520 | unet.enable_gradient_checkpointing() | 523 | unet.enable_gradient_checkpointing() |
521 | 524 | ||
@@ -594,7 +597,7 @@ def main(): | |||
594 | pixel_values = torch.stack(pixel_values) | 597 | pixel_values = torch.stack(pixel_values) |
595 | pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) | 598 | pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) |
596 | 599 | ||
597 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids | 600 | input_ids = prompt_processor.unify_input_ids(input_ids) |
598 | 601 | ||
599 | batch = { | 602 | batch = { |
600 | "prompts": prompts, | 603 | "prompts": prompts, |
@@ -607,7 +610,7 @@ def main(): | |||
607 | datamodule = CSVDataModule( | 610 | datamodule = CSVDataModule( |
608 | data_file=args.train_data_file, | 611 | data_file=args.train_data_file, |
609 | batch_size=args.train_batch_size, | 612 | batch_size=args.train_batch_size, |
610 | tokenizer=tokenizer, | 613 | prompt_processor=prompt_processor, |
611 | instance_identifier=args.instance_identifier, | 614 | instance_identifier=args.instance_identifier, |
612 | class_identifier=args.class_identifier, | 615 | class_identifier=args.class_identifier, |
613 | class_subdir="cls", | 616 | class_subdir="cls", |
@@ -678,8 +681,8 @@ def main(): | |||
678 | optimizer=optimizer, | 681 | optimizer=optimizer, |
679 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 682 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
680 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 683 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
681 | num_cycles=args.lr_cycles or math.ceil( | 684 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
682 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), | 685 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), |
683 | ) | 686 | ) |
684 | else: | 687 | else: |
685 | lr_scheduler = get_scheduler( | 688 | lr_scheduler = get_scheduler( |
@@ -794,7 +797,7 @@ def main(): | |||
794 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 797 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
795 | 798 | ||
796 | # Get the text embedding for conditioning | 799 | # Get the text embedding for conditioning |
797 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 800 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
798 | 801 | ||
799 | # Predict the noise residual | 802 | # Predict the noise residual |
800 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 803 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
@@ -885,7 +888,7 @@ def main(): | |||
885 | 888 | ||
886 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 889 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
887 | 890 | ||
888 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 891 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) |
889 | 892 | ||
890 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 893 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
891 | 894 | ||