summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-16 10:31:55 +0100
committerVolpeon <git@volpeon.ink>2023-01-16 10:31:55 +0100
commit89afcfda3f824cc44221e877182348f9b09687d2 (patch)
tree804b84322e5caa8fb861322ce6970bef4b532c61 /train_ti.py
parentExtended Dreambooth: Train TI tokens separately (diff)
downloadtextual-inversion-diff-89afcfda3f824cc44221e877182348f9b09687d2.tar.gz
textual-inversion-diff-89afcfda3f824cc44221e877182348f9b09687d2.tar.bz2
textual-inversion-diff-89afcfda3f824cc44221e877182348f9b09687d2.zip
Handle empty validation dataset
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py9
1 files changed, 3 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py
index 48a2333..a894ee7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -582,9 +582,6 @@ def main():
582 ) 582 )
583 datamodule.setup() 583 datamodule.setup()
584 584
585 train_dataloader = datamodule.train_dataloader
586 val_dataloader = datamodule.val_dataloader
587
588 if args.num_class_images != 0: 585 if args.num_class_images != 0:
589 generate_class_images( 586 generate_class_images(
590 accelerator, 587 accelerator,
@@ -623,7 +620,7 @@ def main():
623 lr_scheduler = get_scheduler( 620 lr_scheduler = get_scheduler(
624 args.lr_scheduler, 621 args.lr_scheduler,
625 optimizer=optimizer, 622 optimizer=optimizer,
626 num_training_steps_per_epoch=len(train_dataloader), 623 num_training_steps_per_epoch=len(datamodule.train_dataloader),
627 gradient_accumulation_steps=args.gradient_accumulation_steps, 624 gradient_accumulation_steps=args.gradient_accumulation_steps,
628 min_lr=args.lr_min_lr, 625 min_lr=args.lr_min_lr,
629 warmup_func=args.lr_warmup_func, 626 warmup_func=args.lr_warmup_func,
@@ -637,8 +634,8 @@ def main():
637 634
638 trainer( 635 trainer(
639 project="textual_inversion", 636 project="textual_inversion",
640 train_dataloader=train_dataloader, 637 train_dataloader=datamodule.train_dataloader,
641 val_dataloader=val_dataloader, 638 val_dataloader=datamodule.val_dataloader,
642 optimizer=optimizer, 639 optimizer=optimizer,
643 lr_scheduler=lr_scheduler, 640 lr_scheduler=lr_scheduler,
644 num_train_epochs=args.num_train_epochs, 641 num_train_epochs=args.num_train_epochs,