diff options
author | Volpeon <git@volpeon.ink> | 2023-02-08 07:27:55 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-08 07:27:55 +0100 |
commit | 9ea20241bbeb2f32199067096272e13647c512eb (patch) | |
tree | 9e0891a74d0965da75e9d3f30628b69d5ba3deaf /train_ti.py | |
parent | Fix Lora memory usage (diff) | |
download | textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.tar.gz textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.tar.bz2 textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.zip |
Fixed Lora training
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 56f9e97..2840def 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -513,6 +513,12 @@ def main(): | |||
513 | mixed_precision=args.mixed_precision | 513 | mixed_precision=args.mixed_precision |
514 | ) | 514 | ) |
515 | 515 | ||
516 | weight_dtype = torch.float32 | ||
517 | if args.mixed_precision == "fp16": | ||
518 | weight_dtype = torch.float16 | ||
519 | elif args.mixed_precision == "bf16": | ||
520 | weight_dtype = torch.bfloat16 | ||
521 | |||
516 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 522 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
517 | 523 | ||
518 | if args.seed is None: | 524 | if args.seed is None: |
@@ -564,12 +570,6 @@ def main(): | |||
564 | else: | 570 | else: |
565 | optimizer_class = torch.optim.AdamW | 571 | optimizer_class = torch.optim.AdamW |
566 | 572 | ||
567 | weight_dtype = torch.float32 | ||
568 | if args.mixed_precision == "fp16": | ||
569 | weight_dtype = torch.float16 | ||
570 | elif args.mixed_precision == "bf16": | ||
571 | weight_dtype = torch.bfloat16 | ||
572 | |||
573 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | 573 | checkpoint_output_dir = output_dir.joinpath("checkpoints") |
574 | 574 | ||
575 | trainer = partial( | 575 | trainer = partial( |