From 89afcfda3f824cc44221e877182348f9b09687d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 10:31:55 +0100 Subject: Handle empty validation dataset --- train_ti.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 48a2333..a894ee7 100644 --- a/train_ti.py +++ b/train_ti.py @@ -582,9 +582,6 @@ def main(): ) datamodule.setup() - train_dataloader = datamodule.train_dataloader - val_dataloader = datamodule.val_dataloader - if args.num_class_images != 0: generate_class_images( accelerator, @@ -623,7 +620,7 @@ def main(): lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_training_steps_per_epoch=len(train_dataloader), + num_training_steps_per_epoch=len(datamodule.train_dataloader), gradient_accumulation_steps=args.gradient_accumulation_steps, min_lr=args.lr_min_lr, warmup_func=args.lr_warmup_func, @@ -637,8 +634,8 @@ def main(): trainer( project="textual_inversion", - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, + train_dataloader=datamodule.train_dataloader, + val_dataloader=datamodule.val_dataloader, optimizer=optimizer, lr_scheduler=lr_scheduler, num_train_epochs=args.num_train_epochs, -- cgit v1.2.3-54-g00ecf