diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-31 09:34:55 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-31 09:34:55 +0200 |
| commit | be3e05e47cded8487aaa787c54aa74770f9dcac8 (patch) | |
| tree | d7ba88cbe7f72d6e910dd2d7f1916bb198e701fd /train_ti.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-be3e05e47cded8487aaa787c54aa74770f9dcac8.tar.gz textual-inversion-diff-be3e05e47cded8487aaa787c54aa74770f9dcac8.tar.bz2 textual-inversion-diff-be3e05e47cded8487aaa787c54aa74770f9dcac8.zip | |
Fix
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index 9c4ad93..b7ea5f3 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -761,7 +761,8 @@ def main(): | |||
| 761 | num_train_epochs = args.num_train_epochs | 761 | num_train_epochs = args.num_train_epochs |
| 762 | 762 | ||
| 763 | if num_train_epochs is None: | 763 | if num_train_epochs is None: |
| 764 | num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) | 764 | num_images = math.ceil(len(datamodule.train_dataset) / args.train_batch_size) * args.train_batch_size |
| 765 | num_train_epochs = math.ceil(args.num_train_steps / num_images) | ||
| 765 | 766 | ||
| 766 | optimizer = create_optimizer( | 767 | optimizer = create_optimizer( |
| 767 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 768 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
