path: root/textual_inversion.py
author    Volpeon <git@volpeon.ink>  2022-10-18 18:08:32 +0200
committer Volpeon <git@volpeon.ink>  2022-10-18 18:08:32 +0200
commit    2ddfbd65e482fa2361e8ba41b657656f825c9143 (patch)
tree      41cc82e23d82dd620c81f2715a50969b832e9bda /textual_inversion.py
parent    Improved prompt handling (diff)
Adapted other scripts for new prompt processing
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py | 17 ++++++++++-------
1 file changed, 10 insertions(+), 7 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 69d9c7f..8f266e0 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,6 +25,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_a import EulerAScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
 
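
For readers without the repo checked out: a minimal sketch of what the new PromptProcessor could look like, inferred only from how this diff uses it; the actual implementation in models/clip/prompt.py may differ.

    class PromptProcessor:
        # Hypothetical reconstruction from usage in this diff; not the repo's code.
        def __init__(self, tokenizer, text_encoder):
            self.tokenizer = tokenizer
            self.text_encoder = text_encoder

        def unify_input_ids(self, input_ids):
            # Pad a batch of token-id lists to a common length, mirroring the
            # tokenizer.pad call this commit replaces in the collate function.
            return self.tokenizer.pad(
                {"input_ids": input_ids}, padding=True, return_tensors="pt"
            ).input_ids

        def get_embeddings(self, input_ids):
            # Simplest behavior consistent with the training-loop change below:
            # return the encoder's last hidden state as conditioning.
            return self.text_encoder(input_ids)[0]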
@@ -152,7 +153,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
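
This makes "cosine_with_restarts" the default learning-rate schedule, which is what the num_cycles computation changed further down feeds into. A minimal, hypothetical usage sketch with diffusers' get_scheduler (the optimizer and step counts are placeholder values):

    from diffusers.optimization import get_scheduler

    # Sketch only: `optimizer` stands in for any torch optimizer built earlier.
    lr_scheduler = get_scheduler(
        "cosine_with_restarts",
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=1000,
    )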
@@ -516,6 +517,8 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='unet')
 
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
@@ -594,7 +597,7 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
 
-        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+        input_ids = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
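
The collate step now delegates padding to the prompt processor. Assuming unify_input_ids behaves like the tokenizer.pad call it replaces, its effect on a ragged batch would be roughly this (toy token ids, purely illustrative):

    # Two prompts of different token lengths, as lists of ids.
    input_ids = [
        [49406, 320, 49407],
        [49406, 320, 1125, 539, 320, 2368, 49407],
    ]
    padded = prompt_processor.unify_input_ids(input_ids)
    # Expected: a LongTensor of shape (2, 7), with the shorter row
    # right-padded using the tokenizer's pad token id.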
@@ -607,7 +610,7 @@ def main():
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
+        prompt_processor=prompt_processor,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
@@ -678,8 +681,8 @@ def main():
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
+            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
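
The restart count now grows with the square root of the post-warmup epoch count instead of half of it, so long runs get fewer, longer cosine cycles. A quick numeric comparison of the two formulas (hypothetical values):

    import math

    max_train_steps = 3000
    lr_warmup_steps = 200
    num_update_steps_per_epoch = 100  # 28 post-warmup "epochs" of updates

    old_cycles = math.ceil(((max_train_steps - lr_warmup_steps) / num_update_steps_per_epoch) / 2)
    new_cycles = math.ceil(math.sqrt((max_train_steps - lr_warmup_steps) / num_update_steps_per_epoch))
    print(old_cycles, new_cycles)  # 14 6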
@@ -794,7 +797,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
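
Both the training loop here and the validation loop in the next hunk now obtain the conditioning through the prompt processor rather than calling the text encoder directly. Given the parent commit "Improved prompt handling", one plausible reading is that get_embeddings windows over-long prompts through the encoder; a speculative sketch of that idea, not the repo's actual code:

    import torch

    def get_embeddings(self, input_ids):
        # Speculative: split ids into windows of the tokenizer's max length,
        # encode each window, and concatenate along the sequence axis so
        # prompts may exceed the usual 77-token CLIP limit.
        max_len = self.tokenizer.model_max_length
        chunks = input_ids.split(max_len, dim=1)
        return torch.cat([self.text_encoder(chunk)[0] for chunk in chunks], dim=1)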
@@ -885,7 +888,7 @@ def main():
 
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 