diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 5a4c47b..a29c507 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -442,6 +442,12 @@ def main(): | |||
| 442 | mixed_precision=args.mixed_precision | 442 | mixed_precision=args.mixed_precision |
| 443 | ) | 443 | ) |
| 444 | 444 | ||
| 445 | weight_dtype = torch.float32 | ||
| 446 | if args.mixed_precision == "fp16": | ||
| 447 | weight_dtype = torch.float16 | ||
| 448 | elif args.mixed_precision == "bf16": | ||
| 449 | weight_dtype = torch.bfloat16 | ||
| 450 | |||
| 445 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 451 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
| 446 | 452 | ||
| 447 | if args.seed is None: | 453 | if args.seed is None: |
| @@ -495,12 +501,6 @@ def main(): | |||
| 495 | else: | 501 | else: |
| 496 | optimizer_class = torch.optim.AdamW | 502 | optimizer_class = torch.optim.AdamW |
| 497 | 503 | ||
| 498 | weight_dtype = torch.float32 | ||
| 499 | if args.mixed_precision == "fp16": | ||
| 500 | weight_dtype = torch.float16 | ||
| 501 | elif args.mixed_precision == "bf16": | ||
| 502 | weight_dtype = torch.bfloat16 | ||
| 503 | |||
| 504 | trainer = partial( | 504 | trainer = partial( |
| 505 | train, | 505 | train, |
| 506 | accelerator=accelerator, | 506 | accelerator=accelerator, |
