summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-18 18:08:32 +0200
committerVolpeon <git@volpeon.ink>2022-10-18 18:08:32 +0200
commit2ddfbd65e482fa2361e8ba41b657656f825c9143 (patch)
tree41cc82e23d82dd620c81f2715a50969b832e9bda /dreambooth.py
parentImproved prompt handling (diff)
downloadtextual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.tar.gz
textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.tar.bz2
textual-inversion-diff-2ddfbd65e482fa2361e8ba41b657656f825c9143.zip
Adapted other scripts for new prompt processing
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 770ad38..9786e0f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -25,6 +25,7 @@ from slugify import slugify
25from schedulers.scheduling_euler_a import EulerAScheduler 25from schedulers.scheduling_euler_a import EulerAScheduler
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from data.csv import CSVDataModule 27from data.csv import CSVDataModule
28from models.clip.prompt import PromptProcessor
28 29
29logger = get_logger(__name__) 30logger = get_logger(__name__)
30 31
@@ -141,7 +142,7 @@ def parse_args():
141 parser.add_argument( 142 parser.add_argument(
142 "--lr_scheduler", 143 "--lr_scheduler",
143 type=str, 144 type=str,
144 default="cosine", 145 default="cosine_with_restarts",
145 help=( 146 help=(
146 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 147 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
147 ' "constant", "constant_with_warmup"]' 148 ' "constant", "constant_with_warmup"]'
@@ -494,6 +495,8 @@ def main():
494 device=accelerator.device 495 device=accelerator.device
495 ) if args.use_ema else None 496 ) if args.use_ema else None
496 497
498 prompt_processor = PromptProcessor(tokenizer, text_encoder)
499
497 if args.gradient_checkpointing: 500 if args.gradient_checkpointing:
498 unet.enable_gradient_checkpointing() 501 unet.enable_gradient_checkpointing()
499 502
@@ -557,7 +560,7 @@ def main():
557 pixel_values = torch.stack(pixel_values) 560 pixel_values = torch.stack(pixel_values)
558 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) 561 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
559 562
560 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 563 input_ids = prompt_processor.unify_input_ids(input_ids)
561 564
562 batch = { 565 batch = {
563 "prompts": prompts, 566 "prompts": prompts,
@@ -570,7 +573,7 @@ def main():
570 datamodule = CSVDataModule( 573 datamodule = CSVDataModule(
571 data_file=args.train_data_file, 574 data_file=args.train_data_file,
572 batch_size=args.train_batch_size, 575 batch_size=args.train_batch_size,
573 tokenizer=tokenizer, 576 prompt_processor=prompt_processor,
574 instance_identifier=args.instance_identifier, 577 instance_identifier=args.instance_identifier,
575 class_identifier=args.class_identifier, 578 class_identifier=args.class_identifier,
576 class_subdir="cls", 579 class_subdir="cls",
@@ -641,8 +644,8 @@ def main():
641 optimizer=optimizer, 644 optimizer=optimizer,
642 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 645 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
643 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 646 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
644 num_cycles=args.lr_cycles or math.ceil( 647 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
645 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), 648 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
646 ) 649 )
647 else: 650 else:
648 lr_scheduler = get_scheduler( 651 lr_scheduler = get_scheduler(
@@ -756,7 +759,7 @@ def main():
756 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 759 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
757 760
758 # Get the text embedding for conditioning 761 # Get the text embedding for conditioning
759 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 762 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
760 763
761 # Predict the noise residual 764 # Predict the noise residual
762 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 765 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -832,7 +835,7 @@ def main():
832 835
833 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 836 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
834 837
835 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 838 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
836 839
837 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 840 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
838 841