From 7b149930bb53b93db74106ad20a30abf4b114f9b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 13:49:35 +0100 Subject: Removed PromptProcessor, modularized training loop --- train_dreambooth.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 0fe590f..fbbe6c2 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -27,7 +27,6 @@ from training.common import loss_step, generate_class_images, get_scheduler from training.lr import LRFinder from training.util import AverageMeter, CheckpointerBase, save_args from models.clip.embeddings import patch_managed_embeddings -from models.clip.prompt import PromptProcessor from models.clip.tokenizer import MultiCLIPTokenizer logger = get_logger(__name__) @@ -690,8 +689,6 @@ def main(): text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) - prompt_processor = PromptProcessor(tokenizer, text_encoder) - if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * @@ -751,7 +748,7 @@ def main(): datamodule = VlpnDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, - prompt_processor=prompt_processor, + tokenizer=tokenizer, class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, @@ -876,7 +873,7 @@ def main(): vae, noise_scheduler, unet, - prompt_processor, + text_encoder, args.num_class_images, args.prior_loss_weight, args.seed, -- cgit v1.2.3-54-g00ecf