diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 0fe590f..fbbe6c2 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -27,7 +27,6 @@ from training.common import loss_step, generate_class_images, get_scheduler | |||
27 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
28 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args |
29 | from models.clip.embeddings import patch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
30 | from models.clip.prompt import PromptProcessor | ||
31 | from models.clip.tokenizer import MultiCLIPTokenizer | 30 | from models.clip.tokenizer import MultiCLIPTokenizer |
32 | 31 | ||
33 | logger = get_logger(__name__) | 32 | logger = get_logger(__name__) |
@@ -690,8 +689,6 @@ def main(): | |||
690 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) | 689 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
691 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) | 690 | text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) |
692 | 691 | ||
693 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
694 | |||
695 | if args.scale_lr: | 692 | if args.scale_lr: |
696 | args.learning_rate = ( | 693 | args.learning_rate = ( |
697 | args.learning_rate * args.gradient_accumulation_steps * | 694 | args.learning_rate * args.gradient_accumulation_steps * |
@@ -751,7 +748,7 @@ def main(): | |||
751 | datamodule = VlpnDataModule( | 748 | datamodule = VlpnDataModule( |
752 | data_file=args.train_data_file, | 749 | data_file=args.train_data_file, |
753 | batch_size=args.train_batch_size, | 750 | batch_size=args.train_batch_size, |
754 | prompt_processor=prompt_processor, | 751 | tokenizer=tokenizer, |
755 | class_subdir=args.class_image_dir, | 752 | class_subdir=args.class_image_dir, |
756 | num_class_images=args.num_class_images, | 753 | num_class_images=args.num_class_images, |
757 | size=args.resolution, | 754 | size=args.resolution, |
@@ -876,7 +873,7 @@ def main(): | |||
876 | vae, | 873 | vae, |
877 | noise_scheduler, | 874 | noise_scheduler, |
878 | unet, | 875 | unet, |
879 | prompt_processor, | 876 | text_encoder, |
880 | args.num_class_images, | 877 | args.num_class_images, |
881 | args.prior_loss_weight, | 878 | args.prior_loss_weight, |
882 | args.seed, | 879 | args.seed, |