summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-31 14:54:15 +0200
committerVolpeon <git@volpeon.ink>2023-03-31 14:54:15 +0200
commit5acae38f9b995fbaeb42a1504cce88bd18154f12 (patch)
tree28abdb148fc133782fb5ee55b157cf1b12327c9d /train_ti.py
parentFix (diff)
downloadtextual-inversion-diff-5acae38f9b995fbaeb42a1504cce88bd18154f12.tar.gz
textual-inversion-diff-5acae38f9b995fbaeb42a1504cce88bd18154f12.tar.bz2
textual-inversion-diff-5acae38f9b995fbaeb42a1504cce88bd18154f12.zip
Fix
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py
index 7900fbd..b182a72 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -766,9 +766,10 @@ def main():
766 datamodule.setup() 766 datamodule.setup()
767 767
768 num_train_epochs = args.num_train_epochs 768 num_train_epochs = args.num_train_epochs
769 sample_frequency = args.sample_frequency
769 if num_train_epochs is None: 770 if num_train_epochs is None:
770 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset)) 771 num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
771 sample_frequency = math.ceil(num_train_epochs * (args.sample_frequency / args.num_train_steps)) 772 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
772 773
773 optimizer = create_optimizer( 774 optimizer = create_optimizer(
774 text_encoder.text_model.embeddings.temp_token_embedding.parameters(), 775 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),