Source code for bacpipe.embedding_evaluation.visualization.visualize

import json

import matplotlib.pyplot as plt
import numpy as np

import bacpipe.embedding_evaluation.label_embeddings as le
from bacpipe.embedding_evaluation.visualization.visualize_predictions import (
    load_results, plot_per_class_metrics
)
import matplotlib

from matplotlib.figure import Figure

import logging

logger = logging.getLogger(__name__)



matplotlib.rcParams.update(
    {
        "figure.dpi": 600,  # High-resolution figures
        "savefig.dpi": 600,  # Exported plot DPI
        "font.size": 12,  # Better font readability
        "axes.titlesize": 12,
        "legend.fontsize": 10,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
    }
)


[docs] def visualise_results_across_models(plot_path, task_name, model_list): """ Create visualizations to compare models by specified tasks. Parameters ---------- path_func : function return the paths when given a model name plot_path : pathlib.Path object path to overview plots task_name : str name of task model_list : list list of models """ metrics = load_results(le.get_paths, task_name, model_list) with open(plot_path.joinpath(f"{task_name}_results.json"), "w") as f: json.dump(metrics, f, indent=2) if task_name == "probing": iterate_through_subtasks( plot_per_class_metrics, plot_path, task_name, model_list, metrics ) iterate_through_subtasks( plot_overview_metrics, plot_path, task_name, model_list, metrics ) else: plot_overview_metrics(plot_path, task_name, model_list, metrics, path_func=le.get_paths)
[docs] def iterate_through_subtasks(plot_func, plot_path, task_name, model_list, metrics): """ For classification multiple subtasks exist (linear and knn). Iterate over each of the subtasks and call the plotting functions to create the visualizations. Parameters ---------- plot_func : function returns model specific paths when model name is passed plot_path : pathlib.Path object path to store overview plots task_name : str name of task model_list : list list of models metrics : dict performance dictionary """ subtasks = np.unique([s.split("(")[-1][:-1] for s in list(metrics.keys())]) for subtask in subtasks: sub_task_metrics = { k.split("(")[0]: v for k, v in metrics.items() if subtask in k } plot_func(plot_path, f"{subtask} {task_name}", model_list, sub_task_metrics)
[docs] def clustering_overview( path_func, label_by, no_noise, model_list, label_column, **kwargs ): """ Create overview plots for clustering metrics. Parameters ---------- path_func : function function to return the paths when model name is given label_by : str key of default_labels dict no_noise : bool whether to plot the metrics with or without noise model_list : list list of models label_column : str label as defined in the annotations.csv file kwargs : dict additional arguments for plotting Returns ------- plt.plot object figure handle """ fig = Figure(figsize=(12, 8)) ax = fig.subplots() fig.subplots_adjust(bottom=0.25, right=0.9) flat_metrics = dict() for model_name in model_list: with open(path_func(model_name).clust_path / "clust_results.json", "r") as f: metrics = json.load(f) if no_noise: no_noise = "_no_noise" else: no_noise = "" flat_metrics[model_name] = dict() if label_by == label_column: flat_metrics[model_name][label_column] = metrics["ARI"][ f"{label_column}{no_noise}-kmeans" ] elif not label_by == "kmeans": flat_metrics[model_name]["kmeans"] = metrics["ARI"][ f"kmeans{no_noise}-{label_by}" ] if not label_by == label_column and label_column in [ k.split("-")[0] for k in metrics["ARI"].keys() ]: flat_metrics[model_name][label_column] = metrics["ARI"][ f"{label_column}{no_noise}-{label_by}" ] return generate_bar_plot(flat_metrics, fig, ax, **kwargs)
[docs] def plot_clusterings( path_func, model_name, label_by, no_noise, fig=None, ax=None, **kwargs ): """ Plot the clustering metrics for a given model and label type. Parameters ---------- path_func : function function to return the paths when model name is given model_name : str name of model label_by : str key of default_labels dict no_noise : bool whether to plot the metrics with or without noise fig : plt.plot object, optional figure handle, by default None ax : plt.plot object, optional axes handle, by default None Returns ------- plt.plot object figure handle """ if no_noise: no_noise = "_no_noise" else: no_noise = "" clust_path = path_func(model_name).clust_path / "clust_results.json" if not clust_path.exists(): error = ( f"\nThe clustering file {clust_path} does not exist. Perhaps it was not " "created yet. To avoid getting this error set `overwrite=True`." ) logger.exception(error) raise AssertionError(error) with open(clust_path, "r") as f: metrics = json.load(f) if not fig and not ax: fig = Figure(figsize=(5, 4)) ax = fig.subplots() fig.subplots_adjust(left=0.4, bottom=0.25) keys = [ l for l in np.unique([k.split("-")[0] for k in metrics["AMI"].keys()]) if not "no_noise" in l ] flat_metrics = {k: dict() for k in keys} if label_by == "ground_truth": return None for compared_to in keys: try: flat_metrics[compared_to]["AMI"] = metrics["AMI"][ f"{compared_to+no_noise}-{label_by}" ] flat_metrics[compared_to]["ARI"] = metrics["ARI"][ f"{compared_to+no_noise}-{label_by}" ] except KeyError: flat_metrics[compared_to]["AMI"] = 0 flat_metrics[compared_to]["ARI"] = 0 return generate_bar_plot(flat_metrics, fig, ax, **kwargs)
[docs] def generate_bar_plot( metrics, fig, ax, x_label="Metric value", no_legend=False, **kwargs ): bar_height = 1 / (len(list(metrics.values())[0].keys()) + 1) cmap = plt.cm.tab10 colors = cmap(np.arange(len(list(metrics.values())[0].keys())) % cmap.N) metrics_sorted = dict(sorted(metrics.items())) for out_idx, (_, metric) in enumerate(metrics_sorted.items()): for inner_idx, (key, value) in enumerate(metric.items()): ax.barh( out_idx - bar_height * inner_idx, value, label=key, height=bar_height, color=colors[inner_idx], ) ax.set_yticks(np.arange(len(metrics_sorted.keys()))) ax.set_yticklabels(list(metrics_sorted.keys())) ax.set_xlabel(x_label) ax.vlines(0, -1, out_idx, linestyles="dashed", color="black", linewidth=0.3) hand, labl = ax.get_legend_handles_labels() if not no_legend: fig.legend( hand[: inner_idx + 1], labl[: inner_idx + 1], fontsize=10, markerscale=15, loc="outside lower center", ncol=min(len(labl), 5), ) return fig
[docs] def plot_overview_metrics( plot_path, task_name, model_list, metrics, path_func=None, return_fig=False, sort_string="kmeans-audio_file_name", ): """ Visualization of task performance by model accross all classes. Resulting plot is stored in the plot path. Parameters ---------- plot_path : pathlib.Path object path to store overview plots task_name : str name of task model_list : list list of models metrics : dict performance dictionary sort_string : str string to sort the metrics by, defaults to "kmeans-audio_file_name" """ # TODO when first ran mutliple models and then just one, metrics # doesn't know the current model and this should be caught if not metrics: res_path = path_func(model_list[0]).plot_path.parent.parent.joinpath("overview") with open(res_path.joinpath(f"probing_results.json"), "r") as f: metrics = json.load(f) metrics = { k.split("(")[0]: v["overall"] for k, v in metrics.items() if task_name in k } if "probing" in task_name: metrics = {k: v["overall"] for k, v in metrics.items()} fig = Figure(figsize=(12, 6)) ax = fig.subplots() if len(model_list) == 1 and model_list[0] not in metrics: error = ( "\nIt seems like you have selected a single model in a folder where previously " "multiple models were computed. Try selecting at least two models, that way " "this error should be fixed." ) elif not all([model in metrics for model in model_list]): raise AttributeError( "It seems like you have selected models for which the classification scores " "haven't been saved yet, but for some reason bacpipe didn't realize this. " "Try running bacpipe again with the setting `overwrite` set to `True`." ) num_metrics = len(metrics[model_list[0]]) bar_width = 1 / (num_metrics + 1) cmap = plt.cm.tab10 cols = cmap(np.arange(num_metrics) % cmap.N) if task_name == "clustering": sort_by = lambda item: list(item[-1].values())[-1][sort_string] else: sort_by = lambda item: list(item[-1].values())[0] metrics = dict(sorted(metrics.items(), key=sort_by, reverse=True)) if task_name == "clustering": metrics = { k: { k: v[sort_string] for k, v in metrics[k].items() if sort_string in v.keys() } for k, v in metrics.items() } for mod_idx, (model, d) in enumerate(metrics.items()): for i, (metric, value) in enumerate(d.items()): ax.bar( mod_idx - bar_width * i, value, label=metric, width=bar_width, color=cols[i], ) ax.set_ylabel("Various Metrics") ax.set_xlabel("Models") ax.set_xticks(np.arange(len(metrics.keys())) - bar_width * (num_metrics - 1) / 2) ax.set_xticklabels( [model.upper() for model in metrics.keys()], rotation=45, horizontalalignment="right", ) ax.set_title(f"Overall Metrics for {task_name} Across Models") fig.subplots_adjust(right=0.75, bottom=0.3) ax.legend( loc="upper left", bbox_to_anchor=(1.05, 1), title="Metrics", labels=d.keys(), fontsize=10, ) if return_fig: return fig file = ( f"overview_metrics_{task_name}_" + "-".join([m[:2] for m in metrics.keys()]) + ".png" ) plot_path.mkdir(exist_ok=True, parents=True) fig.savefig( plot_path.joinpath(file), dpi=300, ) plt.close(fig)