diff options
author | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
commit | 94b676d91382267e7429bd68362019868affd9d1 (patch) | |
tree | 513697739ab25217cbfcff630299d02b1f6e98c8 /train_lora.py | |
parent | Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline (diff) | |
download | textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.gz textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.bz2 textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.zip |
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py index 5fd05cc..a8c1cf6 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -392,7 +392,7 @@ def main(): | |||
392 | args = parse_args() | 392 | args = parse_args() |
393 | 393 | ||
394 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 394 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
395 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) | 395 | output_dir = Path(args.output_dir) / slugify(args.project) / now |
396 | output_dir.mkdir(parents=True, exist_ok=True) | 396 | output_dir.mkdir(parents=True, exist_ok=True) |
397 | 397 | ||
398 | accelerator = Accelerator( | 398 | accelerator = Accelerator( |
@@ -408,7 +408,7 @@ def main(): | |||
408 | elif args.mixed_precision == "bf16": | 408 | elif args.mixed_precision == "bf16": |
409 | weight_dtype = torch.bfloat16 | 409 | weight_dtype = torch.bfloat16 |
410 | 410 | ||
411 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 411 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
412 | 412 | ||
413 | if args.seed is None: | 413 | if args.seed is None: |
414 | args.seed = torch.random.seed() >> 32 | 414 | args.seed = torch.random.seed() >> 32 |
@@ -489,8 +489,8 @@ def main(): | |||
489 | prior_loss_weight=args.prior_loss_weight, | 489 | prior_loss_weight=args.prior_loss_weight, |
490 | ) | 490 | ) |
491 | 491 | ||
492 | checkpoint_output_dir = output_dir.joinpath("model") | 492 | checkpoint_output_dir = output_dir / "model" |
493 | sample_output_dir = output_dir.joinpath(f"samples") | 493 | sample_output_dir = output_dir/"samples" |
494 | 494 | ||
495 | datamodule = VlpnDataModule( | 495 | datamodule = VlpnDataModule( |
496 | data_file=args.train_data_file, | 496 | data_file=args.train_data_file, |
@@ -562,7 +562,7 @@ def main(): | |||
562 | sample_image_size=args.sample_image_size, | 562 | sample_image_size=args.sample_image_size, |
563 | ) | 563 | ) |
564 | 564 | ||
565 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | 565 | plot_metrics(metrics, output_dir/"lr.png") |
566 | 566 | ||
567 | 567 | ||
568 | if __name__ == "__main__": | 568 | if __name__ == "__main__": |