From 85953e46c0d36658293b1cd39e26f5f550b173f8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Apr 2023 21:41:53 +0200
Subject: Fix

---
 train_lora.py | 77 ++++++++++++++++++++++++++++++++---------------------------
 1 file changed, 42 insertions(+), 35 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 5b0a292..9f17495 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -444,6 +444,12 @@ def parse_args():
         default=1,
         help="How often to save a checkpoint and sample image",
     )
+    parser.add_argument(
+        "--pti_sample_frequency",
+        type=int,
+        default=1,
+        help="How often to save a checkpoint and sample image during PTI",
+    )
     parser.add_argument(
         "--sample_image_size",
         type=int,
@@ -887,47 +893,48 @@ def main():
     pti_datamodule.setup()
 
     num_pti_epochs = args.num_pti_epochs
-    pti_sample_frequency = args.sample_frequency
+    pti_sample_frequency = args.pti_sample_frequency
     if num_pti_epochs is None:
         num_pti_epochs = math.ceil(
             args.num_pti_steps / len(pti_datamodule.train_dataset)
         ) * args.pti_gradient_accumulation_steps
         pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
 
-    pti_optimizer = create_optimizer(
-        [
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_pti,
-                "weight_decay": 0,
-            },
-        ]
-    )
-
-    pti_lr_scheduler = create_lr_scheduler(
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        optimizer=pti_optimizer,
-        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-        train_epochs=num_pti_epochs,
-    )
-
-    metrics = trainer(
-        strategy=lora_strategy,
-        pti_mode=True,
-        project="pti",
-        train_dataloader=pti_datamodule.train_dataloader,
-        val_dataloader=pti_datamodule.val_dataloader,
-        optimizer=pti_optimizer,
-        lr_scheduler=pti_lr_scheduler,
-        num_train_epochs=num_pti_epochs,
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        # --
-        sample_output_dir=pti_sample_output_dir,
-        checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=math.inf,
-    )
-
-    plot_metrics(metrics, pti_output_dir / "lr.png")
+    if num_pti_epochs > 0:
+        pti_optimizer = create_optimizer(
+            [
+                {
+                    "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                    "lr": args.learning_rate_pti,
+                    "weight_decay": 0,
+                },
+            ]
+        )
+
+        pti_lr_scheduler = create_lr_scheduler(
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_pti_epochs,
+        )
+
+        metrics = trainer(
+            strategy=lora_strategy,
+            pti_mode=True,
+            project="pti",
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_pti_epochs,
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            # --
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+        )
+
+        plot_metrics(metrics, pti_output_dir / "lr.png")
 
     # LORA
     # --------------------------------------------------------------------------------
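
A quick sanity check of the frequency rescaling in the second hunk, using hypothetical numbers (the real values come from the parsed CLI args and the PTI dataset, neither of which this patch shows). When args.num_pti_epochs is None, main() derives an epoch count from the step budget and converts the step-based --pti_sample_frequency into an epoch-based one:

    import math

    # Hypothetical stand-ins for args.num_pti_steps,
    # len(pti_datamodule.train_dataset),
    # args.pti_gradient_accumulation_steps and args.pti_sample_frequency.
    num_pti_steps = 500
    train_dataset_len = 40
    pti_gradient_accumulation_steps = 4
    pti_sample_frequency = 10  # measured in steps

    # Same arithmetic as the patched block in main().
    num_pti_epochs = (
        math.ceil(num_pti_steps / train_dataset_len)
        * pti_gradient_accumulation_steps
    )
    pti_sample_frequency = math.ceil(
        num_pti_epochs * (pti_sample_frequency / num_pti_steps)
    )

    print(num_pti_epochs, pti_sample_frequency)  # 52 2

So a budget of 500 PTI steps maps to 52 epochs here, and a 10-step sample frequency becomes "sample every 2 epochs".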
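
Two behavioural fixes land in that hunk. First, the trainer call previously hard-coded sample_frequency=math.inf, so the PTI phase effectively never sampled or checkpointed; it now respects the new --pti_sample_frequency, which also stops --sample_frequency from doing double duty for both phases. Second, the optimizer, LR scheduler and trainer invocation are now built only inside if num_pti_epochs > 0:, so a configuration that resolves to zero PTI epochs skips the phase entirely instead of calling the trainer with nothing to train.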