summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-08 07:27:55 +0100
committerVolpeon <git@volpeon.ink>2023-02-08 07:27:55 +0100
commit9ea20241bbeb2f32199067096272e13647c512eb (patch)
tree9e0891a74d0965da75e9d3f30628b69d5ba3deaf /train_dreambooth.py
parentFix Lora memory usage (diff)
downloadtextual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.tar.gz
textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.tar.bz2
textual-inversion-diff-9ea20241bbeb2f32199067096272e13647c512eb.zip
Fixed Lora training
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a4c47b..a29c507 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -442,6 +442,12 @@ def main():
442 mixed_precision=args.mixed_precision 442 mixed_precision=args.mixed_precision
443 ) 443 )
444 444
445 weight_dtype = torch.float32
446 if args.mixed_precision == "fp16":
447 weight_dtype = torch.float16
448 elif args.mixed_precision == "bf16":
449 weight_dtype = torch.bfloat16
450
445 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) 451 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
446 452
447 if args.seed is None: 453 if args.seed is None:
@@ -495,12 +501,6 @@ def main():
495 else: 501 else:
496 optimizer_class = torch.optim.AdamW 502 optimizer_class = torch.optim.AdamW
497 503
498 weight_dtype = torch.float32
499 if args.mixed_precision == "fp16":
500 weight_dtype = torch.float16
501 elif args.mixed_precision == "bf16":
502 weight_dtype = torch.bfloat16
503
504 trainer = partial( 504 trainer = partial(
505 train, 505 train,
506 accelerator=accelerator, 506 accelerator=accelerator,