From 2ddfbd65e482fa2361e8ba41b657656f825c9143 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 18 Oct 2022 18:08:32 +0200
Subject: Adapted other scripts for new prompt processing

---
 textual_inversion.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index 69d9c7f..8f266e0 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,6 +25,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_a import EulerAScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
 
@@ -152,7 +153,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -516,6 +517,8 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='unet')
 
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
@@ -594,7 +597,7 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
 
-        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+        input_ids = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
@@ -607,7 +610,7 @@ def main():
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
+        prompt_processor=prompt_processor,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
@@ -678,8 +681,8 @@ def main():
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
+            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
@@ -794,7 +797,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -885,7 +888,7 @@ def main():
 
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-- 
cgit v1.2.3-54-g00ecf
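
Note on the new API: the sketch below shows what a PromptProcessor compatible with this patch could look like. It is reconstructed only from the two call sites the patch replaces (the tokenizer.pad(...) call in collate_fn and the text_encoder(...)[0] call in the training/validation loops); the actual class in models/clip/prompt.py is not part of this diff and may do more (for example, handle prompts longer than the tokenizer's maximum length), so treat this as an assumption-laden sketch rather than the repository's implementation.

    # Hypothetical sketch of a PromptProcessor that would satisfy the calls in
    # this patch; the real models/clip/prompt.py implementation may differ.
    import torch
    from transformers import CLIPTextModel, CLIPTokenizer


    class PromptProcessor:
        def __init__(self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):
            self.tokenizer = tokenizer
            self.text_encoder = text_encoder

        def unify_input_ids(self, input_ids: list[list[int]]) -> torch.Tensor:
            # Pad a batch of token-id lists to a common length and return a tensor,
            # mirroring the tokenizer.pad(...) call this patch removes from collate_fn.
            return self.tokenizer.pad(
                {"input_ids": input_ids},
                padding=True,
                return_tensors="pt",
            ).input_ids

        def get_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Encode padded token ids into the hidden states the UNet is conditioned
            # on, mirroring the text_encoder(batch["input_ids"])[0] call this patch removes.
            return self.text_encoder(input_ids)[0]

On the scheduler change: with the default now set to cosine_with_restarts, the fallback cycle count scales with the square root of the post-warmup epoch count instead of half of it. With hypothetical values max_train_steps=6000, lr_warmup_steps=200 and num_update_steps_per_epoch=100, the removed default gives ceil(58 / 2) = 29 cycles, while the new default gives ceil(sqrt(58)) = 8.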