From 2ddfbd65e482fa2361e8ba41b657656f825c9143 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 18 Oct 2022 18:08:32 +0200 Subject: Adapted other scripts for new prompt processing --- dreambooth.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 770ad38..9786e0f 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -25,6 +25,7 @@ from slugify import slugify from schedulers.scheduling_euler_a import EulerAScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule +from models.clip.prompt import PromptProcessor logger = get_logger(__name__) @@ -141,7 +142,7 @@ def parse_args(): parser.add_argument( "--lr_scheduler", type=str, - default="cosine", + default="cosine_with_restarts", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' @@ -494,6 +495,8 @@ def main(): device=accelerator.device ) if args.use_ema else None + prompt_processor = PromptProcessor(tokenizer, text_encoder) + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -557,7 +560,7 @@ def main(): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) - input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + input_ids = prompt_processor.unify_input_ids(input_ids) batch = { "prompts": prompts, @@ -570,7 +573,7 @@ def main(): datamodule = CSVDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, - tokenizer=tokenizer, + prompt_processor=prompt_processor, instance_identifier=args.instance_identifier, class_identifier=args.class_identifier, class_subdir="cls", @@ -641,8 +644,8 @@ def main(): optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_cycles or math.ceil( - ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2), + num_cycles=args.lr_cycles or math.ceil(math.sqrt( + ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), ) else: lr_scheduler = get_scheduler( @@ -756,7 +759,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -832,7 +835,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"]) noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample -- cgit v1.2.3-54-g00ecf