From 306f2bfb620e6882737658bd3694c79365d75e4b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 18 Oct 2022 15:23:40 +0200 Subject: Improved prompt handling --- dreambooth_plus.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) (limited to 'dreambooth_plus.py') diff --git a/dreambooth_plus.py b/dreambooth_plus.py index ae31377..fa3a22b 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py @@ -26,6 +26,7 @@ from slugify import slugify from schedulers.scheduling_euler_a import EulerAScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule +from models.clip.prompt import PromptProcessor logger = get_logger(__name__) @@ -147,7 +148,7 @@ def parse_args(): parser.add_argument( "--learning_rate_text", type=float, - default=1e-6, + default=5e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -470,7 +471,7 @@ class Checkpointer: for i in range(self.sample_batches): batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] prompt = [ - [p.format(self.instance_identifier) for p in prompt] + prompt.format(self.instance_identifier) for batch in batches for prompt in batch["prompts"] ][:self.sample_batch_size] @@ -573,6 +574,8 @@ def main(): device=accelerator.device ) if args.use_ema else None + prompt_processor = PromptProcessor(tokenizer, text_encoder) + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -663,7 +666,7 @@ def main(): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) - input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + input_ids = prompt_processor.unify_input_ids(input_ids) batch = { "prompts": prompts, @@ -673,21 +676,10 @@ def main(): } return batch - def encode_input_ids(input_ids): - text_embeddings = [] - - for ids in input_ids: - embeddings = text_encoder(ids)[0] - embeddings = embeddings.reshape((1, -1, 768)) - text_embeddings.append(embeddings) - - text_embeddings = torch.cat(text_embeddings) - return text_embeddings - datamodule = CSVDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, - tokenizer=tokenizer, + prompt_processor=prompt_processor, instance_identifier=args.instance_identifier, class_identifier=args.class_identifier, class_subdir="cls", @@ -727,7 +719,7 @@ def main(): with torch.inference_mode(): for batch in batched_data: image_name = [item.class_image_path for item in batch] - prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch] + prompt = [item.prompt.format(args.class_identifier) for item in batch] nprompt = [item.nprompt for item in batch] images = pipeline( @@ -875,7 +867,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = encode_input_ids(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -974,7 +966,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = encode_input_ids(batch["input_ids"]) + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample -- cgit v1.2.3-54-g00ecf