diff options
author | Volpeon <git@volpeon.ink> | 2023-01-17 07:20:45 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-17 07:20:45 +0100 |
commit | 5821523a524190490a287c5e2aacb6e72cc3a4cf (patch) | |
tree | c0eac536c754f078683be6d59893ad23d70baf51 /train_dreambooth.py | |
parent | Training update (diff) | |
download | textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.gz textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.bz2 textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.zip |
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index d722e68..48bdcf8 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -14,8 +14,7 @@ from slugify import slugify | |||
14 | 14 | ||
15 | from util import load_config, load_embeddings_from_dir | 15 | from util import load_config, load_embeddings_from_dir |
16 | from data.csv import VlpnDataModule, keyword_filter | 16 | from data.csv import VlpnDataModule, keyword_filter |
17 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models | 17 | from training.functional import train, get_models |
18 | from training.strategy.ti import textual_inversion_strategy | ||
19 | from training.strategy.dreambooth import dreambooth_strategy | 18 | from training.strategy.dreambooth import dreambooth_strategy |
20 | from training.optimization import get_scheduler | 19 | from training.optimization import get_scheduler |
21 | from training.util import save_args | 20 | from training.util import save_args |
@@ -610,7 +609,7 @@ def main(): | |||
610 | ) | 609 | ) |
611 | 610 | ||
612 | trainer( | 611 | trainer( |
613 | callbacks_fn=dreambooth_strategy, | 612 | strategy=dreambooth_strategy, |
614 | project="dreambooth", | 613 | project="dreambooth", |
615 | train_dataloader=datamodule.train_dataloader, | 614 | train_dataloader=datamodule.train_dataloader, |
616 | val_dataloader=datamodule.val_dataloader, | 615 | val_dataloader=datamodule.val_dataloader, |