diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/train_ti.py b/train_ti.py index 61195f6..d2ca7eb 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -492,7 +492,7 @@ def parse_args(): | |||
| 492 | class Checkpointer(CheckpointerBase): | 492 | class Checkpointer(CheckpointerBase): |
| 493 | def __init__( | 493 | def __init__( |
| 494 | self, | 494 | self, |
| 495 | weight_dtype, | 495 | weight_dtype: torch.dtype, |
| 496 | accelerator: Accelerator, | 496 | accelerator: Accelerator, |
| 497 | vae: AutoencoderKL, | 497 | vae: AutoencoderKL, |
| 498 | unet: UNet2DConditionModel, | 498 | unet: UNet2DConditionModel, |
| @@ -587,7 +587,9 @@ def main(): | |||
| 587 | 587 | ||
| 588 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 588 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) |
| 589 | 589 | ||
| 590 | args.seed = args.seed or (torch.random.seed() >> 32) | 590 | if args.seed is None: |
| 591 | args.seed = torch.random.seed() >> 32 | ||
| 592 | |||
| 591 | set_seed(args.seed) | 593 | set_seed(args.seed) |
| 592 | 594 | ||
| 593 | save_args(basepath, args) | 595 | save_args(basepath, args) |
| @@ -622,7 +624,8 @@ def main(): | |||
| 622 | num_vectors=args.num_vectors | 624 | num_vectors=args.num_vectors |
| 623 | ) | 625 | ) |
| 624 | 626 | ||
| 625 | print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") | 627 | if len(placeholder_token_ids) != 0: |
| 628 | print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") | ||
| 626 | 629 | ||
| 627 | if args.use_ema: | 630 | if args.use_ema: |
| 628 | ema_embeddings = EMAModel( | 631 | ema_embeddings = EMAModel( |
