summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-17 07:20:45 +0100
committerVolpeon <git@volpeon.ink>2023-01-17 07:20:45 +0100
commit5821523a524190490a287c5e2aacb6e72cc3a4cf (patch)
treec0eac536c754f078683be6d59893ad23d70baf51 /train_dreambooth.py
parentTraining update (diff)
downloadtextual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.gz
textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.bz2
textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.zip
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d722e68..48bdcf8 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -14,8 +14,7 @@ from slugify import slugify
14 14
15from util import load_config, load_embeddings_from_dir 15from util import load_config, load_embeddings_from_dir
16from data.csv import VlpnDataModule, keyword_filter 16from data.csv import VlpnDataModule, keyword_filter
17from training.functional import train, generate_class_images, add_placeholder_tokens, get_models 17from training.functional import train, get_models
18from training.strategy.ti import textual_inversion_strategy
19from training.strategy.dreambooth import dreambooth_strategy 18from training.strategy.dreambooth import dreambooth_strategy
20from training.optimization import get_scheduler 19from training.optimization import get_scheduler
21from training.util import save_args 20from training.util import save_args
@@ -610,7 +609,7 @@ def main():
610 ) 609 )
611 610
612 trainer( 611 trainer(
613 callbacks_fn=dreambooth_strategy, 612 strategy=dreambooth_strategy,
614 project="dreambooth", 613 project="dreambooth",
615 train_dataloader=datamodule.train_dataloader, 614 train_dataloader=datamodule.train_dataloader,
616 val_dataloader=datamodule.val_dataloader, 615 val_dataloader=datamodule.val_dataloader,