From 8fbc878aec9a7d1bab510491547d0753fe617975 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 09:56:15 +0200
Subject: Update

---
 train_lora.py | 29 ++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 39bf455..0b26965 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -241,6 +241,12 @@ def parse_args():
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
+    parser.add_argument(
+        "--pti_gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
     parser.add_argument(
         "--lora_r",
         type=int,
@@ -475,6 +481,12 @@ def parse_args():
         default=1,
         help="Batch size (per device) for the training dataloader."
     )
+    parser.add_argument(
+        "--pti_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
     parser.add_argument(
         "--sample_steps",
         type=int,
@@ -694,8 +706,8 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )
         args.learning_rate_pti = (
-            args.learning_rate_pti * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
+            args.learning_rate_pti * args.pti_gradient_accumulation_steps *
+            args.pti_batch_size * accelerator.num_processes
         )

     if args.find_lr:
@@ -808,7 +820,6 @@ def main():
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         offset_noise_strength=args.offset_noise_strength,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
@@ -820,7 +831,6 @@ def main():
     create_datamodule = partial(
         VlpnDataModule,
         data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         class_subdir=args.class_image_dir,
         with_guidance=args.guidance_scale != 0,
@@ -843,7 +853,6 @@ def main():
     create_lr_scheduler = partial(
         get_scheduler,
         args.lr_scheduler,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         min_lr=args.lr_min_lr,
         warmup_func=args.lr_warmup_func,
         annealing_func=args.lr_annealing_func,
@@ -863,6 +872,7 @@ def main():
     pti_sample_output_dir = pti_output_dir / "samples"

     pti_datamodule = create_datamodule(
+        batch_size=args.pti_batch_size,
         filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
     )
     pti_datamodule.setup()
@@ -872,7 +882,7 @@ def main():
     if num_pti_epochs is None:
         num_pti_epochs = math.ceil(
             args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.gradient_accumulation_steps
+        ) * args.pti_gradient_accumulation_steps
     pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))

     pti_optimizer = create_optimizer(
@@ -886,6 +896,7 @@ def main():
     )

     pti_lr_scheduler = create_lr_scheduler(
+        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         optimizer=pti_optimizer,
         num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
         train_epochs=num_pti_epochs,
@@ -893,12 +904,13 @@ def main():

     metrics = trainer(
         strategy=textual_inversion_strategy,
-        project="ti",
+        project="pti",
         train_dataloader=pti_datamodule.train_dataloader,
         val_dataloader=pti_datamodule.val_dataloader,
         optimizer=pti_optimizer,
         lr_scheduler=pti_lr_scheduler,
         num_train_epochs=num_pti_epochs,
+        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
         # --
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
@@ -920,6 +932,7 @@ def main():
     lora_sample_output_dir = lora_output_dir / "samples"

     lora_datamodule = create_datamodule(
+        batch_size=args.train_batch_size,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
     )
     lora_datamodule.setup()
@@ -954,6 +967,7 @@ def main():
     )

     lora_lr_scheduler = create_lr_scheduler(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         optimizer=lora_optimizer,
         num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
         train_epochs=num_train_epochs,
@@ -967,6 +981,7 @@ def main():
         optimizer=lora_optimizer,
         lr_scheduler=lora_lr_scheduler,
         num_train_epochs=num_train_epochs,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         # --
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
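For context, the patch keeps shared settings bound in the functools.partial factories (create_datamodule, create_lr_scheduler) and moves the now phase-specific settings (batch size, gradient accumulation) to the PTI and LoRA call sites. A minimal sketch of that pattern follows; the names build_dataloader, train.csv, and the batch sizes are illustrative only and are not taken from train_lora.py:

    from functools import partial


    def build_dataloader(data_file: str, batch_size: int) -> dict:
        # Stand-in for a datamodule factory: just records the settings it was given.
        return {"data_file": data_file, "batch_size": batch_size}


    # Shared settings are bound once into the factory...
    create_datamodule = partial(build_dataloader, data_file="train.csv")

    # ...while per-phase settings are supplied where each phase is set up,
    # mirroring how the patch passes pti_batch_size vs. train_batch_size.
    pti_loader = create_datamodule(batch_size=8)   # PTI phase
    lora_loader = create_datamodule(batch_size=4)  # LoRA phase

    print(pti_loader, lora_loader)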