From 6ecfdb73d150c5a596722ec3234e53f4796fbc78 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 23:09:14 +0100 Subject: Unified training script structure --- train_ti.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 61195f6..d2ca7eb 100644 --- a/train_ti.py +++ b/train_ti.py @@ -492,7 +492,7 @@ def parse_args(): class Checkpointer(CheckpointerBase): def __init__( self, - weight_dtype, + weight_dtype: torch.dtype, accelerator: Accelerator, vae: AutoencoderKL, unet: UNet2DConditionModel, @@ -587,7 +587,9 @@ def main(): logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) - args.seed = args.seed or (torch.random.seed() >> 32) + if args.seed is None: + args.seed = torch.random.seed() >> 32 + set_seed(args.seed) save_args(basepath, args) @@ -622,7 +624,8 @@ def main(): num_vectors=args.num_vectors ) - print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") + if len(placeholder_token_ids) != 0: + print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") if args.use_ema: ema_embeddings = EMAModel( -- cgit v1.2.3-54-g00ecf