path: root/train_ti.py
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 61195f6..d2ca7eb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -492,7 +492,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
-        weight_dtype,
+        weight_dtype: torch.dtype,
         accelerator: Accelerator,
         vae: AutoencoderKL,
         unet: UNet2DConditionModel,
@@ -587,7 +587,9 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    args.seed = args.seed or (torch.random.seed() >> 32)
+    if args.seed is None:
+        args.seed = torch.random.seed() >> 32
+
     set_seed(args.seed)
 
     save_args(basepath, args)
@@ -622,7 +624,8 @@ def main():
         num_vectors=args.num_vectors
     )
 
-    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+    if len(placeholder_token_ids) != 0:
+        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
 
     if args.use_ema:
         ema_embeddings = EMAModel(
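
Note on the seed hunk above: "args.seed or (...)" treats an explicit seed of 0 as falsy and silently replaces it with a random seed, whereas the "is None" check only randomizes when no seed was given. A minimal sketch of the difference (the helper names resolve_seed_old/resolve_seed_new are illustrative only, not part of train_ti.py):

import torch

def resolve_seed_old(seed):
    # Pre-patch behaviour: "or" falls through for any falsy value,
    # so an explicit seed of 0 is replaced by a random one.
    return seed or (torch.random.seed() >> 32)

def resolve_seed_new(seed):
    # Post-patch behaviour: only None triggers a random seed,
    # so seed=0 is honoured.
    if seed is None:
        seed = torch.random.seed() >> 32
    return seed

print(resolve_seed_old(0))  # some random 32-bit value; the requested 0 is lost
print(resolve_seed_new(0))  # 0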