diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 5a4c47b..a29c507 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -442,6 +442,12 @@ def main(): | |||
442 | mixed_precision=args.mixed_precision | 442 | mixed_precision=args.mixed_precision |
443 | ) | 443 | ) |
444 | 444 | ||
445 | weight_dtype = torch.float32 | ||
446 | if args.mixed_precision == "fp16": | ||
447 | weight_dtype = torch.float16 | ||
448 | elif args.mixed_precision == "bf16": | ||
449 | weight_dtype = torch.bfloat16 | ||
450 | |||
445 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 451 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
446 | 452 | ||
447 | if args.seed is None: | 453 | if args.seed is None: |
@@ -495,12 +501,6 @@ def main(): | |||
495 | else: | 501 | else: |
496 | optimizer_class = torch.optim.AdamW | 502 | optimizer_class = torch.optim.AdamW |
497 | 503 | ||
498 | weight_dtype = torch.float32 | ||
499 | if args.mixed_precision == "fp16": | ||
500 | weight_dtype = torch.float16 | ||
501 | elif args.mixed_precision == "bf16": | ||
502 | weight_dtype = torch.bfloat16 | ||
503 | |||
504 | trainer = partial( | 504 | trainer = partial( |
505 | train, | 505 | train, |
506 | accelerator=accelerator, | 506 | accelerator=accelerator, |