diff options
author | Volpeon <git@volpeon.ink> | 2023-02-08 07:27:55 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-08 07:27:55 +0100 |
commit | 9ea20241bbeb2f32199067096272e13647c512eb (patch) | |
tree | 9e0891a74d0965da75e9d3f30628b69d5ba3deaf /train_dreambooth.py | |
parent | Fix Lora memory usage (diff) | |
download | textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.tar.gz textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.tar.bz2 textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.zip |
Fixed Lora training
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 5a4c47b..a29c507 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -442,6 +442,12 @@ def main(): | |||
442 | mixed_precision=args.mixed_precision | 442 | mixed_precision=args.mixed_precision |
443 | ) | 443 | ) |
444 | 444 | ||
445 | weight_dtype = torch.float32 | ||
446 | if args.mixed_precision == "fp16": | ||
447 | weight_dtype = torch.float16 | ||
448 | elif args.mixed_precision == "bf16": | ||
449 | weight_dtype = torch.bfloat16 | ||
450 | |||
445 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 451 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
446 | 452 | ||
447 | if args.seed is None: | 453 | if args.seed is None: |
@@ -495,12 +501,6 @@ def main(): | |||
495 | else: | 501 | else: |
496 | optimizer_class = torch.optim.AdamW | 502 | optimizer_class = torch.optim.AdamW |
497 | 503 | ||
498 | weight_dtype = torch.float32 | ||
499 | if args.mixed_precision == "fp16": | ||
500 | weight_dtype = torch.float16 | ||
501 | elif args.mixed_precision == "bf16": | ||
502 | weight_dtype = torch.bfloat16 | ||
503 | |||
504 | trainer = partial( | 504 | trainer = partial( |
505 | train, | 505 | train, |
506 | accelerator=accelerator, | 506 | accelerator=accelerator, |