diff options
author | Volpeon <git@volpeon.ink> | 2023-04-07 09:09:46 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-07 09:09:46 +0200 |
commit | d952d467d31786f4a85cc4cb009934cd4ebbba71 (patch) | |
tree | 68c3f597a86ef3b98734d80cc783aa1f42fe1a41 /train_lora.py | |
parent | Update (diff) | |
download | textual-inversion-diff-d952d467d31786f4a85cc4cb009934cd4ebbba71.tar.gz textual-inversion-diff-d952d467d31786f4a85cc4cb009934cd4ebbba71.tar.bz2 textual-inversion-diff-d952d467d31786f4a85cc4cb009934cd4ebbba71.zip |
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 303 |
1 files changed, 261 insertions, 42 deletions
diff --git a/train_lora.py b/train_lora.py index 1ca56d9..39bf455 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -1,6 +1,7 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import datetime | 2 | import datetime |
3 | import logging | 3 | import logging |
4 | import itertools | ||
4 | from pathlib import Path | 5 | from pathlib import Path |
5 | from functools import partial | 6 | from functools import partial |
6 | import math | 7 | import math |
@@ -17,9 +18,10 @@ import transformers | |||
17 | 18 | ||
18 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
19 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
20 | from training.functional import train, get_models | 21 | from training.functional import train, add_placeholder_tokens, get_models |
21 | from training.lr import plot_metrics | 22 | from training.lr import plot_metrics |
22 | from training.strategy.lora import lora_strategy | 23 | from training.strategy.lora import lora_strategy |
24 | from training.strategy.ti import textual_inversion_strategy | ||
23 | from training.optimization import get_scheduler | 25 | from training.optimization import get_scheduler |
24 | from training.util import save_args | 26 | from training.util import save_args |
25 | 27 | ||
@@ -81,6 +83,43 @@ def parse_args(): | |||
81 | help="The name of the current project.", | 83 | help="The name of the current project.", |
82 | ) | 84 | ) |
83 | parser.add_argument( | 85 | parser.add_argument( |
86 | "--placeholder_tokens", | ||
87 | type=str, | ||
88 | nargs='*', | ||
89 | help="A token to use as a placeholder for the concept.", | ||
90 | ) | ||
91 | parser.add_argument( | ||
92 | "--initializer_tokens", | ||
93 | type=str, | ||
94 | nargs='*', | ||
95 | help="A token to use as initializer word." | ||
96 | ) | ||
97 | parser.add_argument( | ||
98 | "--initializer_noise", | ||
99 | type=float, | ||
100 | default=0, | ||
101 | help="Noise to apply to the initializer word" | ||
102 | ) | ||
103 | parser.add_argument( | ||
104 | "--alias_tokens", | ||
105 | type=str, | ||
106 | nargs='*', | ||
107 | default=[], | ||
108 | help="Tokens to create an alias for." | ||
109 | ) | ||
110 | parser.add_argument( | ||
111 | "--inverted_initializer_tokens", | ||
112 | type=str, | ||
113 | nargs='*', | ||
114 | help="A token to use as initializer word." | ||
115 | ) | ||
116 | parser.add_argument( | ||
117 | "--num_vectors", | ||
118 | type=int, | ||
119 | nargs='*', | ||
120 | help="Number of vectors per embedding." | ||
121 | ) | ||
122 | parser.add_argument( | ||
84 | "--exclude_collections", | 123 | "--exclude_collections", |
85 | type=str, | 124 | type=str, |
86 | nargs='*', | 125 | nargs='*', |
@@ -187,6 +226,16 @@ def parse_args(): | |||
187 | default=2000 | 226 | default=2000 |
188 | ) | 227 | ) |
189 | parser.add_argument( | 228 | parser.add_argument( |
229 | "--num_pti_epochs", | ||
230 | type=int, | ||
231 | default=None | ||
232 | ) | ||
233 | parser.add_argument( | ||
234 | "--num_pti_steps", | ||
235 | type=int, | ||
236 | default=500 | ||
237 | ) | ||
238 | parser.add_argument( | ||
190 | "--gradient_accumulation_steps", | 239 | "--gradient_accumulation_steps", |
191 | type=int, | 240 | type=int, |
192 | default=1, | 241 | default=1, |
@@ -258,6 +307,12 @@ def parse_args(): | |||
258 | help="Initial learning rate (after the potential warmup period) to use.", | 307 | help="Initial learning rate (after the potential warmup period) to use.", |
259 | ) | 308 | ) |
260 | parser.add_argument( | 309 | parser.add_argument( |
310 | "--learning_rate_pti", | ||
311 | type=float, | ||
312 | default=1e-4, | ||
313 | help="Initial learning rate (after the potential warmup period) to use.", | ||
314 | ) | ||
315 | parser.add_argument( | ||
261 | "--scale_lr", | 316 | "--scale_lr", |
262 | action="store_true", | 317 | action="store_true", |
263 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | 318 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", |
@@ -433,6 +488,23 @@ def parse_args(): | |||
433 | help="The weight of prior preservation loss." | 488 | help="The weight of prior preservation loss." |
434 | ) | 489 | ) |
435 | parser.add_argument( | 490 | parser.add_argument( |
491 | "--use_emb_decay", | ||
492 | action="store_true", | ||
493 | help="Whether to use embedding decay." | ||
494 | ) | ||
495 | parser.add_argument( | ||
496 | "--emb_decay_target", | ||
497 | default=0.4, | ||
498 | type=float, | ||
499 | help="Embedding decay target." | ||
500 | ) | ||
501 | parser.add_argument( | ||
502 | "--emb_decay", | ||
503 | default=1e+2, | ||
504 | type=float, | ||
505 | help="Embedding decay factor." | ||
506 | ) | ||
507 | parser.add_argument( | ||
436 | "--max_grad_norm", | 508 | "--max_grad_norm", |
437 | default=1.0, | 509 | default=1.0, |
438 | type=float, | 510 | type=float, |
@@ -464,6 +536,40 @@ def parse_args(): | |||
464 | if args.project is None: | 536 | if args.project is None: |
465 | raise ValueError("You must specify --project") | 537 | raise ValueError("You must specify --project") |
466 | 538 | ||
539 | if isinstance(args.placeholder_tokens, str): | ||
540 | args.placeholder_tokens = [args.placeholder_tokens] | ||
541 | |||
542 | if isinstance(args.initializer_tokens, str): | ||
543 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | ||
544 | |||
545 | if len(args.initializer_tokens) == 0: | ||
546 | raise ValueError("You must specify --initializer_tokens") | ||
547 | |||
548 | if len(args.placeholder_tokens) == 0: | ||
549 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | ||
550 | |||
551 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | ||
552 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | ||
553 | |||
554 | if isinstance(args.inverted_initializer_tokens, str): | ||
555 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | ||
556 | |||
557 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | ||
558 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
559 | args.initializer_tokens += args.inverted_initializer_tokens | ||
560 | |||
561 | if isinstance(args.num_vectors, int): | ||
562 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | ||
563 | |||
564 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | ||
565 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | ||
566 | |||
567 | if args.alias_tokens is None: | ||
568 | args.alias_tokens = [] | ||
569 | |||
570 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | ||
571 | raise ValueError("--alias_tokens must be a list with an even number of items") | ||
572 | |||
467 | if isinstance(args.collection, str): | 573 | if isinstance(args.collection, str): |
468 | args.collection = [args.collection] | 574 | args.collection = [args.collection] |
469 | 575 | ||
@@ -544,6 +650,19 @@ def main(): | |||
544 | if args.gradient_checkpointing: | 650 | if args.gradient_checkpointing: |
545 | unet.enable_gradient_checkpointing() | 651 | unet.enable_gradient_checkpointing() |
546 | 652 | ||
653 | if len(args.alias_tokens) != 0: | ||
654 | alias_placeholder_tokens = args.alias_tokens[::2] | ||
655 | alias_initializer_tokens = args.alias_tokens[1::2] | ||
656 | |||
657 | added_tokens, added_ids = add_placeholder_tokens( | ||
658 | tokenizer=tokenizer, | ||
659 | embeddings=embeddings, | ||
660 | placeholder_tokens=alias_placeholder_tokens, | ||
661 | initializer_tokens=alias_initializer_tokens | ||
662 | ) | ||
663 | embeddings.persist() | ||
664 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | ||
665 | |||
547 | if args.embeddings_dir is not None: | 666 | if args.embeddings_dir is not None: |
548 | embeddings_dir = Path(args.embeddings_dir) | 667 | embeddings_dir = Path(args.embeddings_dir) |
549 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 668 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
@@ -552,6 +671,19 @@ def main(): | |||
552 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 671 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
553 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 672 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
554 | 673 | ||
674 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
675 | tokenizer=tokenizer, | ||
676 | embeddings=embeddings, | ||
677 | placeholder_tokens=args.placeholder_tokens, | ||
678 | initializer_tokens=args.initializer_tokens, | ||
679 | num_vectors=args.num_vectors, | ||
680 | initializer_noise=args.initializer_noise, | ||
681 | ) | ||
682 | stats = list(zip( | ||
683 | args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids | ||
684 | )) | ||
685 | print(f"Training embeddings: {stats}") | ||
686 | |||
555 | if args.scale_lr: | 687 | if args.scale_lr: |
556 | args.learning_rate_unet = ( | 688 | args.learning_rate_unet = ( |
557 | args.learning_rate_unet * args.gradient_accumulation_steps * | 689 | args.learning_rate_unet * args.gradient_accumulation_steps * |
@@ -561,10 +693,15 @@ def main(): | |||
561 | args.learning_rate_text * args.gradient_accumulation_steps * | 693 | args.learning_rate_text * args.gradient_accumulation_steps * |
562 | args.train_batch_size * accelerator.num_processes | 694 | args.train_batch_size * accelerator.num_processes |
563 | ) | 695 | ) |
696 | args.learning_rate_pti = ( | ||
697 | args.learning_rate_pti * args.gradient_accumulation_steps * | ||
698 | args.train_batch_size * accelerator.num_processes | ||
699 | ) | ||
564 | 700 | ||
565 | if args.find_lr: | 701 | if args.find_lr: |
566 | args.learning_rate_unet = 1e-6 | 702 | args.learning_rate_unet = 1e-6 |
567 | args.learning_rate_text = 1e-6 | 703 | args.learning_rate_text = 1e-6 |
704 | args.learning_rate_pti = 1e-6 | ||
568 | args.lr_scheduler = "exponential_growth" | 705 | args.lr_scheduler = "exponential_growth" |
569 | 706 | ||
570 | if args.optimizer == 'adam8bit': | 707 | if args.optimizer == 'adam8bit': |
@@ -663,18 +800,25 @@ def main(): | |||
663 | accelerator=accelerator, | 800 | accelerator=accelerator, |
664 | unet=unet, | 801 | unet=unet, |
665 | text_encoder=text_encoder, | 802 | text_encoder=text_encoder, |
803 | tokenizer=tokenizer, | ||
666 | vae=vae, | 804 | vae=vae, |
667 | noise_scheduler=noise_scheduler, | 805 | noise_scheduler=noise_scheduler, |
668 | dtype=weight_dtype, | 806 | dtype=weight_dtype, |
807 | seed=args.seed, | ||
669 | guidance_scale=args.guidance_scale, | 808 | guidance_scale=args.guidance_scale, |
670 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 809 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
671 | no_val=args.valid_set_size == 0, | 810 | no_val=args.valid_set_size == 0, |
811 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
812 | offset_noise_strength=args.offset_noise_strength, | ||
813 | sample_scheduler=sample_scheduler, | ||
814 | sample_batch_size=args.sample_batch_size, | ||
815 | sample_num_batches=args.sample_batches, | ||
816 | sample_num_steps=args.sample_steps, | ||
817 | sample_image_size=args.sample_image_size, | ||
672 | ) | 818 | ) |
673 | 819 | ||
674 | checkpoint_output_dir = output_dir / "model" | 820 | create_datamodule = partial( |
675 | sample_output_dir = output_dir/"samples" | 821 | VlpnDataModule, |
676 | |||
677 | datamodule = VlpnDataModule( | ||
678 | data_file=args.train_data_file, | 822 | data_file=args.train_data_file, |
679 | batch_size=args.train_batch_size, | 823 | batch_size=args.train_batch_size, |
680 | tokenizer=tokenizer, | 824 | tokenizer=tokenizer, |
@@ -693,71 +837,146 @@ def main(): | |||
693 | train_set_pad=args.train_set_pad, | 837 | train_set_pad=args.train_set_pad, |
694 | valid_set_pad=args.valid_set_pad, | 838 | valid_set_pad=args.valid_set_pad, |
695 | seed=args.seed, | 839 | seed=args.seed, |
840 | dtype=weight_dtype, | ||
841 | ) | ||
842 | |||
843 | create_lr_scheduler = partial( | ||
844 | get_scheduler, | ||
845 | args.lr_scheduler, | ||
846 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
847 | min_lr=args.lr_min_lr, | ||
848 | warmup_func=args.lr_warmup_func, | ||
849 | annealing_func=args.lr_annealing_func, | ||
850 | warmup_exp=args.lr_warmup_exp, | ||
851 | annealing_exp=args.lr_annealing_exp, | ||
852 | cycles=args.lr_cycles, | ||
853 | end_lr=1e2, | ||
854 | warmup_epochs=args.lr_warmup_epochs, | ||
855 | mid_point=args.lr_mid_point, | ||
856 | ) | ||
857 | |||
858 | # PTI | ||
859 | # -------------------------------------------------------------------------------- | ||
860 | |||
861 | pti_output_dir = output_dir / "pti" | ||
862 | pti_checkpoint_output_dir = pti_output_dir / "model" | ||
863 | pti_sample_output_dir = pti_output_dir / "samples" | ||
864 | |||
865 | pti_datamodule = create_datamodule( | ||
866 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | ||
867 | ) | ||
868 | pti_datamodule.setup() | ||
869 | |||
870 | num_pti_epochs = args.num_pti_epochs | ||
871 | pti_sample_frequency = args.sample_frequency | ||
872 | if num_pti_epochs is None: | ||
873 | num_pti_epochs = math.ceil( | ||
874 | args.num_pti_steps / len(pti_datamodule.train_dataset) | ||
875 | ) * args.gradient_accumulation_steps | ||
876 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps)) | ||
877 | |||
878 | pti_optimizer = create_optimizer( | ||
879 | [ | ||
880 | { | ||
881 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | ||
882 | "lr": args.learning_rate_pti, | ||
883 | "weight_decay": 0, | ||
884 | }, | ||
885 | ] | ||
886 | ) | ||
887 | |||
888 | pti_lr_scheduler = create_lr_scheduler( | ||
889 | optimizer=pti_optimizer, | ||
890 | num_training_steps_per_epoch=len(pti_datamodule.train_dataloader), | ||
891 | train_epochs=num_pti_epochs, | ||
892 | ) | ||
893 | |||
894 | metrics = trainer( | ||
895 | strategy=textual_inversion_strategy, | ||
896 | project="ti", | ||
897 | train_dataloader=pti_datamodule.train_dataloader, | ||
898 | val_dataloader=pti_datamodule.val_dataloader, | ||
899 | optimizer=pti_optimizer, | ||
900 | lr_scheduler=pti_lr_scheduler, | ||
901 | num_train_epochs=num_pti_epochs, | ||
902 | # -- | ||
903 | sample_output_dir=pti_sample_output_dir, | ||
904 | checkpoint_output_dir=pti_checkpoint_output_dir, | ||
905 | sample_frequency=pti_sample_frequency, | ||
906 | placeholder_tokens=args.placeholder_tokens, | ||
907 | placeholder_token_ids=placeholder_token_ids, | ||
908 | use_emb_decay=args.use_emb_decay, | ||
909 | emb_decay_target=args.emb_decay_target, | ||
910 | emb_decay=args.emb_decay, | ||
911 | ) | ||
912 | |||
913 | plot_metrics(metrics, output_dir/"lr.png") | ||
914 | |||
915 | # LORA | ||
916 | # -------------------------------------------------------------------------------- | ||
917 | |||
918 | lora_output_dir = output_dir / "pti" | ||
919 | lora_checkpoint_output_dir = lora_output_dir / "model" | ||
920 | lora_sample_output_dir = lora_output_dir / "samples" | ||
921 | |||
922 | lora_datamodule = create_datamodule( | ||
696 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 923 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
697 | dtype=weight_dtype | ||
698 | ) | 924 | ) |
699 | datamodule.setup() | 925 | lora_datamodule.setup() |
700 | 926 | ||
701 | num_train_epochs = args.num_train_epochs | 927 | num_train_epochs = args.num_train_epochs |
702 | sample_frequency = args.sample_frequency | 928 | lora_sample_frequency = args.sample_frequency |
703 | if num_train_epochs is None: | 929 | if num_train_epochs is None: |
704 | num_train_epochs = math.ceil( | 930 | num_train_epochs = math.ceil( |
705 | args.num_train_steps / len(datamodule.train_dataset) | 931 | args.num_train_steps / len(lora_datamodule.train_dataset) |
706 | ) * args.gradient_accumulation_steps | 932 | ) * args.gradient_accumulation_steps |
707 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 933 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) |
708 | 934 | ||
709 | optimizer = create_optimizer( | 935 | lora_optimizer = create_optimizer( |
710 | [ | 936 | [ |
711 | { | 937 | { |
712 | "params": unet.parameters(), | 938 | "params": unet.parameters(), |
713 | "lr": args.learning_rate_unet, | 939 | "lr": args.learning_rate_unet, |
714 | }, | 940 | }, |
715 | { | 941 | { |
716 | "params": text_encoder.parameters(), | 942 | "params": itertools.chain( |
943 | text_encoder.text_model.encoder.parameters(), | ||
944 | text_encoder.text_model.final_layer_norm.parameters(), | ||
945 | ), | ||
946 | "lr": args.learning_rate_text, | ||
947 | }, | ||
948 | { | ||
949 | "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), | ||
717 | "lr": args.learning_rate_text, | 950 | "lr": args.learning_rate_text, |
951 | "weight_decay": 0, | ||
718 | }, | 952 | }, |
719 | ] | 953 | ] |
720 | ) | 954 | ) |
721 | 955 | ||
722 | lr_scheduler = get_scheduler( | 956 | lora_lr_scheduler = create_lr_scheduler( |
723 | args.lr_scheduler, | 957 | optimizer=lora_optimizer, |
724 | optimizer=optimizer, | 958 | num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), |
725 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
726 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
727 | min_lr=args.lr_min_lr, | ||
728 | warmup_func=args.lr_warmup_func, | ||
729 | annealing_func=args.lr_annealing_func, | ||
730 | warmup_exp=args.lr_warmup_exp, | ||
731 | annealing_exp=args.lr_annealing_exp, | ||
732 | cycles=args.lr_cycles, | ||
733 | end_lr=1e2, | ||
734 | train_epochs=num_train_epochs, | 959 | train_epochs=num_train_epochs, |
735 | warmup_epochs=args.lr_warmup_epochs, | ||
736 | mid_point=args.lr_mid_point, | ||
737 | ) | 960 | ) |
738 | 961 | ||
739 | metrics = trainer( | 962 | metrics = trainer( |
740 | strategy=lora_strategy, | 963 | strategy=lora_strategy, |
741 | project="lora", | 964 | project="lora", |
742 | train_dataloader=datamodule.train_dataloader, | 965 | train_dataloader=lora_datamodule.train_dataloader, |
743 | val_dataloader=datamodule.val_dataloader, | 966 | val_dataloader=lora_datamodule.val_dataloader, |
744 | seed=args.seed, | 967 | optimizer=lora_optimizer, |
745 | optimizer=optimizer, | 968 | lr_scheduler=lora_lr_scheduler, |
746 | lr_scheduler=lr_scheduler, | ||
747 | num_train_epochs=num_train_epochs, | 969 | num_train_epochs=num_train_epochs, |
748 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
749 | sample_frequency=sample_frequency, | ||
750 | offset_noise_strength=args.offset_noise_strength, | ||
751 | # -- | 970 | # -- |
752 | tokenizer=tokenizer, | 971 | sample_output_dir=lora_sample_output_dir, |
753 | sample_scheduler=sample_scheduler, | 972 | checkpoint_output_dir=lora_checkpoint_output_dir, |
754 | sample_output_dir=sample_output_dir, | 973 | sample_frequency=lora_sample_frequency, |
755 | checkpoint_output_dir=checkpoint_output_dir, | 974 | placeholder_tokens=args.placeholder_tokens, |
975 | placeholder_token_ids=placeholder_token_ids, | ||
976 | use_emb_decay=args.use_emb_decay, | ||
977 | emb_decay_target=args.emb_decay_target, | ||
978 | emb_decay=args.emb_decay, | ||
756 | max_grad_norm=args.max_grad_norm, | 979 | max_grad_norm=args.max_grad_norm, |
757 | sample_batch_size=args.sample_batch_size, | ||
758 | sample_num_batches=args.sample_batches, | ||
759 | sample_num_steps=args.sample_steps, | ||
760 | sample_image_size=args.sample_image_size, | ||
761 | ) | 980 | ) |
762 | 981 | ||
763 | plot_metrics(metrics, output_dir/"lr.png") | 982 | plot_metrics(metrics, output_dir/"lr.png") |