summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-13 13:49:35 +0100
committerVolpeon <git@volpeon.ink>2023-01-13 13:49:35 +0100
commit7b149930bb53b93db74106ad20a30abf4b114f9b (patch)
tree67c2ccbce2a9838ad8a020ee527b19113e67e30a /train_dreambooth.py
parentAdded TI decay start offset (diff)
downloadtextual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.gz
textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.tar.bz2
textual-inversion-diff-7b149930bb53b93db74106ad20a30abf4b114f9b.zip
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py7
1 files changed, 2 insertions, 5 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0fe590f..fbbe6c2 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -27,7 +27,6 @@ from training.common import loss_step, generate_class_images, get_scheduler
27from training.lr import LRFinder 27from training.lr import LRFinder
28from training.util import AverageMeter, CheckpointerBase, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args
29from models.clip.embeddings import patch_managed_embeddings 29from models.clip.embeddings import patch_managed_embeddings
30from models.clip.prompt import PromptProcessor
31from models.clip.tokenizer import MultiCLIPTokenizer 30from models.clip.tokenizer import MultiCLIPTokenizer
32 31
33logger = get_logger(__name__) 32logger = get_logger(__name__)
@@ -690,8 +689,6 @@ def main():
690 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 689 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
691 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) 690 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
692 691
693 prompt_processor = PromptProcessor(tokenizer, text_encoder)
694
695 if args.scale_lr: 692 if args.scale_lr:
696 args.learning_rate = ( 693 args.learning_rate = (
697 args.learning_rate * args.gradient_accumulation_steps * 694 args.learning_rate * args.gradient_accumulation_steps *
@@ -751,7 +748,7 @@ def main():
751 datamodule = VlpnDataModule( 748 datamodule = VlpnDataModule(
752 data_file=args.train_data_file, 749 data_file=args.train_data_file,
753 batch_size=args.train_batch_size, 750 batch_size=args.train_batch_size,
754 prompt_processor=prompt_processor, 751 tokenizer=tokenizer,
755 class_subdir=args.class_image_dir, 752 class_subdir=args.class_image_dir,
756 num_class_images=args.num_class_images, 753 num_class_images=args.num_class_images,
757 size=args.resolution, 754 size=args.resolution,
@@ -876,7 +873,7 @@ def main():
876 vae, 873 vae,
877 noise_scheduler, 874 noise_scheduler,
878 unet, 875 unet,
879 prompt_processor, 876 text_encoder,
880 args.num_class_images, 877 args.num_class_images,
881 args.prior_loss_weight, 878 args.prior_loss_weight,
882 args.seed, 879 args.seed,