# Source code for eolearn.visualization.utils

"""
The module provides some utility functions for plotting

Copyright (c) 2017- Sinergise and contributors
For the full list of contributors, see the CREDITS file in the root directory of this source tree.

This source code is licensed under the MIT license, see the LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

import itertools

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Colormap


def plot_confusion_matrix(
    confusion_matrix: np.ndarray,
    classes: list[str],
    normalize: bool = True,
    title: str = "Confusion matrix",
    cmap: str | Colormap | None = plt.cm.Blues,
    xlabel: str = "Predicted label",
    ylabel: str = "True label",
) -> None:
    """Make a single confusion matrix plot.

    The plot is drawn through the ``matplotlib.pyplot`` state machine, i.e. into whatever figure/axes are
    current. The function neither creates a figure nor shows or saves it.

    :param confusion_matrix: A 2D array of values. Rows are drawn along the y axis and columns along the x axis,
        so with the default axis labels rows correspond to true classes and columns to predicted classes.
    :param classes: Tick labels, used for both axes.
    :param normalize: If `True`, each row is divided by its sum and values are displayed with 2 decimals on a
        fixed 0..1 color scale. If `False`, raw values are displayed and the color scale spans from 0 to the
        largest value in the matrix.
    :param title: Title of the plot.
    :param cmap: Colormap passed to `plt.imshow`.
    :param xlabel: Label of the x axis.
    :param ylabel: Label of the y axis.
    """
    confusion_matrix = np.asarray(confusion_matrix)

    if normalize:
        # Machine epsilon keeps all-zero rows from causing a division by zero.
        row_totals = confusion_matrix.sum(axis=1)[:, np.newaxis] + np.finfo(float).eps
        confusion_matrix = confusion_matrix.astype(float) / row_totals

    # Normalized values live in [0, 1]. Raw counts need a color scale that follows the data, otherwise every
    # cell with a value of 1 or more would be rendered with the same saturated color.
    color_max = 1.0 if normalize else float(confusion_matrix.max())

    plt.imshow(confusion_matrix, interpolation="nearest", cmap=cmap, vmin=0, vmax=color_max)
    plt.title(title, fontsize=20)

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90, fontsize=20)
    plt.yticks(tick_marks, classes, fontsize=20)

    # The "d" format code raises a ValueError for floats, so it is only used for non-float raw values.
    has_float_values = normalize or np.issubdtype(confusion_matrix.dtype, np.floating)
    fmt = ".2f" if has_float_values else "d"

    # Cells above the midpoint of the color scale get white text, the rest black text.
    threshold = color_max / 2.0
    for (row, col), value in np.ndenumerate(confusion_matrix):
        plt.text(
            col,
            row,
            format(value, fmt),
            horizontalalignment="center",
            fontsize=12,
            color="white" if value > threshold else "black",
        )

    plt.tight_layout()
    plt.ylabel(ylabel, fontsize=20)
    plt.xlabel(xlabel, fontsize=20)