summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a4c47b..a29c507 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -442,6 +442,12 @@ def main():
442 mixed_precision=args.mixed_precision 442 mixed_precision=args.mixed_precision
443 ) 443 )
444 444
445 weight_dtype = torch.float32
446 if args.mixed_precision == "fp16":
447 weight_dtype = torch.float16
448 elif args.mixed_precision == "bf16":
449 weight_dtype = torch.bfloat16
450
445 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) 451 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
446 452
447 if args.seed is None: 453 if args.seed is None:
@@ -495,12 +501,6 @@ def main():
495 else: 501 else:
496 optimizer_class = torch.optim.AdamW 502 optimizer_class = torch.optim.AdamW
497 503
498 weight_dtype = torch.float32
499 if args.mixed_precision == "fp16":
500 weight_dtype = torch.float16
501 elif args.mixed_precision == "bf16":
502 weight_dtype = torch.bfloat16
503
504 trainer = partial( 504 trainer = partial(
505 train, 505 train,
506 accelerator=accelerator, 506 accelerator=accelerator,