1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
from pathlib import Path
import matplotlib.pyplot as plt
def plot_metrics(
    metrics: tuple[list[float], list[float], list[float]],
    output_file: Path,
    skip_start: int = 10,
    skip_end: int = 5,
) -> None:
    """Plot loss and accuracy against learning rate and save the figure.

    Loss is drawn in red on the left y-axis and accuracy in blue on a twin
    right y-axis; both share a logarithmic learning-rate x-axis.

    Args:
        metrics: ``(lrs, losses, accs)`` -- three parallel sequences, one
            entry per recorded point.
        output_file: Destination passed to ``Figure.savefig`` (saved at
            300 dpi).
        skip_start: Number of leading points to drop (expected >= 0).
        skip_end: Number of trailing points to drop (expected >= 0);
            ``0`` keeps every point through the end.
    """
    lrs, losses, accs = metrics
    # ``-0`` is ``0`` and would slice everything away, so map 0 -> None
    # ("through the end").
    stop = -skip_end if skip_end else None
    window = slice(skip_start, stop)
    lrs, losses, accs = lrs[window], losses[window], accs[window]

    fig, ax_loss = plt.subplots()
    try:
        ax_acc = ax_loss.twinx()

        ax_loss.plot(lrs, losses, color="red")
        ax_loss.set_xscale("log")
        ax_loss.set_xlabel("Learning rate")
        ax_loss.set_ylabel("Loss")

        ax_acc.plot(lrs, accs, color="blue")
        ax_acc.set_xscale("log")
        ax_acc.set_ylabel("Accuracy")

        fig.savefig(output_file, dpi=300)
    finally:
        # Close this specific figure even if plotting/saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
|