From 9ea20241bbeb2f32199067096272e13647c512eb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 8 Feb 2023 07:27:55 +0100 Subject: Fixed Lora training --- train_dreambooth.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 5a4c47b..a29c507 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -442,6 +442,12 @@ def main(): mixed_precision=args.mixed_precision ) + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) if args.seed is None: @@ -495,12 +501,6 @@ def main(): else: optimizer_class = torch.optim.AdamW - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - trainer = partial( train, accelerator=accelerator, -- cgit v1.2.3-54-g00ecf