From 9ea20241bbeb2f32199067096272e13647c512eb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 8 Feb 2023 07:27:55 +0100 Subject: Fixed Lora training --- train_ti.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 56f9e97..2840def 100644 --- a/train_ti.py +++ b/train_ti.py @@ -513,6 +513,12 @@ def main(): mixed_precision=args.mixed_precision ) + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) if args.seed is None: @@ -564,12 +570,6 @@ def main(): else: optimizer_class = torch.optim.AdamW - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - checkpoint_output_dir = output_dir.joinpath("checkpoints") trainer = partial( -- cgit v1.2.3-54-g00ecf