from pathlib import Path

import matplotlib.pyplot as plt


def _trim(values: list[float], skip_start: int, skip_end: int) -> list[float]:
    """Return *values* without its first ``skip_start`` and last ``skip_end`` items."""
    # An explicit stop index avoids the ``values[a:-0]`` pitfall (which would be
    # empty). Clamping at 0 makes an oversized ``skip_end`` yield an empty list
    # instead of wrapping around to a negative index.
    stop = max(len(values) - skip_end, 0)
    return values[skip_start:stop]


def plot_metrics(
    metrics: tuple[list[float], list[float], list[float]],
    output_file: Path,
    skip_start: int = 10,
    skip_end: int = 5,
) -> None:
    """Plot loss and accuracy against learning rate and save the figure.

    Loss is drawn in red on the left y-axis and accuracy in blue on a twin
    right y-axis. The learning-rate x-axis is log-scaled.

    Args:
        metrics: ``(learning_rates, losses, accuracies)`` as three lists; each
            one is trimmed independently by the same amounts.
        output_file: Destination passed to ``savefig`` (written at 300 dpi).
        skip_start: Number of leading points to drop from each series.
        skip_end: Number of trailing points to drop from each series.

    Raises:
        ValueError: If ``skip_start`` or ``skip_end`` is negative.
    """
    if skip_start < 0 or skip_end < 0:
        raise ValueError(
            f"skip_start and skip_end must be non-negative, "
            f"got skip_start={skip_start}, skip_end={skip_end}"
        )

    lrs, losses, accs = (_trim(series, skip_start, skip_end) for series in metrics)

    fig, ax_loss = plt.subplots()
    try:
        ax_acc = ax_loss.twinx()

        ax_loss.plot(lrs, losses, color="red")
        ax_loss.set_xscale("log")
        ax_loss.set_xlabel("Learning rate")
        # Colour each y-label like its curve so the two series can be told apart.
        ax_loss.set_ylabel("Loss", color="red")

        ax_acc.plot(lrs, accs, color="blue")
        ax_acc.set_xscale("log")
        ax_acc.set_ylabel("Accuracy", color="blue")

        # Save/close this specific figure rather than pyplot's "current" one.
        fig.savefig(output_file, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)