diff options
author | Volpeon <git@volpeon.ink> | 2023-01-13 23:09:14 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-13 23:09:14 +0100 |
commit | 6ecfdb73d150c5a596722ec3234e53f4796fbc78 (patch) | |
tree | 797bc01768f71a74f944bf1bf18e9bf62665ee4e /train_ti.py | |
parent | Reverted modularization mostly (diff) | |
download | textual-inversion-diff-6ecfdb73d150c5a596722ec3234e53f4796fbc78.tar.gz textual-inversion-diff-6ecfdb73d150c5a596722ec3234e53f4796fbc78.tar.bz2 textual-inversion-diff-6ecfdb73d150c5a596722ec3234e53f4796fbc78.zip |
Unified training script structure
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 9 |
1 file changed, 6 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 61195f6..d2ca7eb 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -492,7 +492,7 @@ def parse_args(): | |||
492 | class Checkpointer(CheckpointerBase): | 492 | class Checkpointer(CheckpointerBase): |
493 | def __init__( | 493 | def __init__( |
494 | self, | 494 | self, |
495 | weight_dtype, | 495 | weight_dtype: torch.dtype, |
496 | accelerator: Accelerator, | 496 | accelerator: Accelerator, |
497 | vae: AutoencoderKL, | 497 | vae: AutoencoderKL, |
498 | unet: UNet2DConditionModel, | 498 | unet: UNet2DConditionModel, |
@@ -587,7 +587,9 @@ def main(): | |||
587 | 587 | ||
588 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 588 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) |
589 | 589 | ||
590 | args.seed = args.seed or (torch.random.seed() >> 32) | 590 | if args.seed is None: |
591 | args.seed = torch.random.seed() >> 32 | ||
592 | |||
591 | set_seed(args.seed) | 593 | set_seed(args.seed) |
592 | 594 | ||
593 | save_args(basepath, args) | 595 | save_args(basepath, args) |
@@ -622,7 +624,8 @@ def main(): | |||
622 | num_vectors=args.num_vectors | 624 | num_vectors=args.num_vectors |
623 | ) | 625 | ) |
624 | 626 | ||
625 | print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") | 627 | if len(placeholder_token_ids) != 0: |
628 | print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") | ||
626 | 629 | ||
627 | if args.use_ema: | 630 | if args.use_ema: |
628 | ema_embeddings = EMAModel( | 631 | ema_embeddings = EMAModel( |