diff options
author | Volpeon <git@volpeon.ink> | 2023-01-19 09:04:39 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-19 09:04:39 +0100 |
commit | 2469501c3951a9ed86c820cddf7b32144a4a1c8d (patch) | |
tree | 9820efaa12fd31670616c1fd9da3e6bb06580aaf /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-2469501c3951a9ed86c820cddf7b32144a4a1c8d.tar.gz textual-inversion-diff-2469501c3951a9ed86c820cddf7b32144a4a1c8d.tar.bz2 textual-inversion-diff-2469501c3951a9ed86c820cddf7b32144a4a1c8d.zip |
Move Accelerator preparation into strategy
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 7aa4960..451b61b 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -159,7 +159,7 @@ def parse_args(): | |||
159 | parser.add_argument( | 159 | parser.add_argument( |
160 | "--tag_dropout", | 160 | "--tag_dropout", |
161 | type=float, | 161 | type=float, |
162 | default=0.1, | 162 | default=0, |
163 | help="Tag dropout probability.", | 163 | help="Tag dropout probability.", |
164 | ) | 164 | ) |
165 | parser.add_argument( | 165 | parser.add_argument( |
@@ -407,7 +407,7 @@ def parse_args(): | |||
407 | ) | 407 | ) |
408 | parser.add_argument( | 408 | parser.add_argument( |
409 | "--emb_decay", | 409 | "--emb_decay", |
410 | default=1e-2, | 410 | default=10, |
411 | type=float, | 411 | type=float, |
412 | help="Embedding decay factor." | 412 | help="Embedding decay factor." |
413 | ) | 413 | ) |
@@ -597,7 +597,7 @@ def main(): | |||
597 | 597 | ||
598 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): | 598 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): |
599 | if len(placeholder_tokens) == 1: | 599 | if len(placeholder_tokens) == 1: |
600 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token[0]}") | 600 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") |
601 | else: | 601 | else: |
602 | sample_output_dir = output_dir.joinpath("samples") | 602 | sample_output_dir = output_dir.joinpath("samples") |
603 | 603 | ||