summaryrefslogtreecommitdiffstats
path: root/training/lr.py
blob: a75078f6d828cff381177b82a6977b528d6babfa (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from pathlib import Path

import matplotlib.pyplot as plt


def plot_metrics(
    metrics: tuple[list[float], list[float], list[float]],
    output_file: Path,
    skip_start: int = 10,
    skip_end: int = 5,
):
    """Plot loss and accuracy against learning rate and save the figure.

    Loss is drawn in red on the left y-axis and accuracy in blue on a twin
    right y-axis; both share a log-scaled learning-rate x-axis.

    Args:
        metrics: ``(lrs, losses, accs)`` -- three parallel sequences with one
            entry per recorded point.
        output_file: Destination path for the saved image (written at 300 dpi).
        skip_start: Number of leading points to drop before plotting.
        skip_end: Number of trailing points to drop before plotting; ``0``
            keeps everything up to the end.
    """
    lrs, losses, accs = metrics

    # A stop index of ``-0`` would produce an empty slice, so map
    # skip_end == 0 to ``None`` ("through the end").
    window = slice(skip_start, -skip_end or None)
    lrs, losses, accs = lrs[window], losses[window], accs[window]

    fig, ax_loss = plt.subplots()
    try:
        ax_acc = ax_loss.twinx()

        ax_loss.plot(lrs, losses, color="red")
        ax_loss.set_xscale("log")
        ax_loss.set_xlabel("Learning rate")
        # Colour each y-label like its curve so the two axes are distinguishable.
        ax_loss.set_ylabel("Loss", color="red")

        ax_acc.plot(lrs, accs, color="blue")
        ax_acc.set_xscale("log")
        ax_acc.set_ylabel("Accuracy", color="blue")

        fig.savefig(output_file, dpi=300)
    finally:
        # Close this specific figure even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)