Source code for bacpipe.embedding_evaluation.probing.probe

import logging
import json
import numpy as np
import torch

from pathlib import Path

import bacpipe
logger = logging.getLogger(__name__)

from .train_probe import train_probe, LinearProbe
from .evaluate_probe import eval_probe
from .dataset_probe import generate_annotations_for_probing_task


[docs] def embeds_array_without_noise(embeds, ground_truth, label_column, **kwargs): if len(ground_truth[f"label:{label_column}"].shape) > 1: bool_array = np.any(ground_truth[f"label:{label_column}"] > -1, axis=1) else: bool_array = ground_truth[f"label:{label_column}"] > -1 if isinstance(embeds, np.ndarray): return embeds[bool_array] elif isinstance(embeds, dict): return np.concatenate(list(embeds.values()))[ bool_array ]
[docs] def probing_pipeline( model_name, ground_truth, embeds, paths=None, name='linear', overwrite=True, label_column=bacpipe.settings.label_column, **kwargs ): """ Probing pipeline consisting of building the classifier, evaluating it and saving metrics and plots of performance. Parameters ---------- paths : SimpleNamespace object dict with attributes corresponding to paths for loading and saving embeds : np.array embeddings name : string Type of Probing dataset_csv_path : string name of Probing dataframe as specified in settings.yaml overwrite : bool overwrite existing Probing?, defaults to False """ if not kwargs: kwargs = {**vars(bacpipe.settings)} kwargs.pop('label_column') if not paths: get_paths_func = bacpipe.make_set_paths_func( bacpipe.config.audio_dir, bacpipe.settings.main_results_dir ) paths = get_paths_func(model_name) if ( overwrite or name=='knn' or not paths.probe_path.joinpath(f"probe_results_{name}.json").exists() ): df = generate_annotations_for_probing_task( ground_truth, paths, label_column=label_column, **kwargs ) if len(df) == 0: logger.exception( "Not enough data in annotations to perform probing task" ) return None embeds = embeds_array_without_noise( embeds, ground_truth, label_column=label_column, **kwargs ) if not len(embeds) > 0: error = ( "\nNo embeddings were found for classification task. " "Are you sure there are annotations for the data and the annotations.csv file " "has been correctly linked? If you didn't intent do do classification, " "simply remove it from the evaluation tasks list in the config.yaml file." ) logger.exception(error) raise AssertionError(error) label2index = {label: i for i, label in enumerate(df.label.unique())} probe = train_probe(embeds, df, label2index, config=name, **kwargs) metrics = eval_probe( probe, embeds, df, label2index, config=name, paths=paths, **kwargs ) else: logger.info( f"Classification file probe_results_{name}.json already exists and" " so is not computed. If you want to overwrite existing results, " "set overwrite to True in config.yaml." ) from bacpipe.embedding_evaluation.probing.train_probe import LinearProbe state_dict = torch.load(paths.probe_path / f"{name}_probe.pt") probe = LinearProbe( in_dim=embeds.shape[-1], out_dim=list(state_dict.values())[-1].shape[0], **kwargs ) probe.load_state_dict(state_dict=state_dict) with open(paths.probe_path / "label2index.json", "r") as f: label2index = json.load(f) load_path = paths.probe_path.joinpath(f"probe_results_{name}.json") with open(load_path, "r") as f: metrics = json.load(f) return probe, label2index, metrics
[docs] def prepare_probe_inference(model, probe_path=''): from bacpipe import config, settings if probe_path == '': import bacpipe.embedding_evaluation.label_embeddings as le path_func = le.make_set_paths_func( config.audio_dir, settings.main_results_dir, settings.dim_reduc_parent_dir ) probe_path = ( path_func(model).probe_path / 'linear_probe.pt' ).as_posix() with open(Path(probe_path).parent / 'label2index.json', 'r') as f: label2index = json.load(f) probe_weights = torch.load(probe_path, map_location=settings.device) probe = LinearProbe( probe_weights['probe.weight'].shape[-1], len(label2index) ) probe.load_state_dict(probe_weights) probe.to(settings.device) return probe, label2index
[docs] def run_probe_inference( model, linear_probe, threshold, embeds=None, return_binary_presence=True, callbacks=None ): if embeds is None: from bacpipe.core.experiment_manager import Loader from bacpipe import config, settings ld = Loader( audio_dir=config.audio_dir, model_name=model, **vars(settings) ) embeds = torch.Tensor(ld.embeddings(return_type='array')).to(settings.device) import torch.nn.functional as F return_values = [] for idx, batch in enumerate(embeds): logits = linear_probe(batch) probabilities = F.softmax(logits, dim=0).detach().cpu().numpy() if return_binary_presence: binary_presence = np.zeros(probabilities.shape, dtype=np.int8) binary_presence[probabilities > threshold] = 1 return_values.append(binary_presence.tolist()) return_dtype = np.int8 else: return_values.append(probabilities.tolist()) return_dtype = np.float32 if isinstance(callbacks, dict) and hasattr(callbacks, 'progress_bar'): callbacks.progress_bar.value = int((idx+1)/len(embeds)*100) return np.array(return_values, dtype=return_dtype)