import panel as pn
import matplotlib
import sys
import seaborn as sns
import numpy as np
import logging
logger = logging.getLogger("bacpipe")
import importlib.resources as pkg_resources
import bacpipe.imgs
from .visualize_embeddings import (
plot_embeddings,
plot_comparison,
EmbedAndLabelLoader,
)
from . import tooltips
from .visualize import (
plot_clusterings,
clustering_overview,
plot_overview_metrics,
)
from .visualize_spectrograms import SpectrogramPlot
from .visualize_predictions import (
plot_classification_results,
plot_classification_heatmap,
PredictionsLoader
)
import bacpipe.embedding_evaluation.label_embeddings as le
from .dashboard_utils import DashBoardHelper
# ---- plotting settings ------------------------------------------------------
# Global seaborn theme applied to every matplotlib figure produced afterwards.
sns.set_theme(style="whitegrid")
# "agg" is matplotlib's non-GUI raster backend: figures are rendered
# off-screen (the dashboard embeds them) instead of opening windows.
matplotlib.use(backend="agg")
# Load panel's plotly extension so that ``pn.pane.Plotly`` panes can render.
pn.extension("plotly")
class DashBoard(DashBoardHelper):
    """Interactive panel dashboard for exploring bacpipe evaluation results.

    Shows 2D embeddings, spectrograms, clustering and probing results per
    model, an overview across all models, and pages for loading classifier
    predictions. Per-page widget state lives in dictionaries keyed by
    ``widget_idx`` (presumably filled by ``DashBoardHelper.make_sidebar``,
    which is not visible here). Call :meth:`build_layout` to assemble the
    tabs into ``self.app``.
    """

    def __init__(
        self,
        model_names,
        audio_dir,
        main_results_dir,
        default_label_keys,
        evaluation_task,
        dim_reduction_model,
        dim_reduc_parent_dir,
        **kwargs,
    ):
        """
        Parameters
        ----------
        model_names : list
            Names of the embedding models. The first model is used to
            discover which extra label sources exist (classifier
            annotations, ground truth, kmeans clusterings).
        audio_dir : path-like
            Directory with the audio files.
        main_results_dir : path-like
            Root directory of the pipeline results.
        default_label_keys : list
            Label keys always offered for colouring; copied, never mutated.
        evaluation_task : container
            Evaluation tasks that were run; checked for ``"clustering"``
            and ``"probing"``.
        dim_reduction_model : str
            Dimensionality reduction used for the 2D embedding plots.
        dim_reduc_parent_dir : path-like
            Parent directory of the dimensionality-reduction results.
        **kwargs
            Forwarded to the path helper, loaders and plotting functions;
            layout options such as ``accordion_width`` and
            ``spectrogram_plot_height`` are read from here.
        """
        self.models = model_names
        self.default_label_keys = default_label_keys
        self.audio_dir = audio_dir
        self.path_func = le.make_set_paths_func(
            audio_dir, main_results_dir, dim_reduc_parent_dir, **kwargs
        )
        first_model_paths = self.path_func(model_names[0])

        # Label keys offered for colouring; extended below with whatever
        # extra label sources exist on disk for the first model.
        self.label_by = default_label_keys.copy()
        preds_path = first_model_paths.preds_path
        if preds_path.exists() and "default_classifier" not in self.label_by:
            classifier_files = preds_path.rglob("*_classifier_annotations.csv")
            if any(path.exists() for path in classifier_files):
                self.label_by += ["default_classifier"]

        self.plot_path = first_model_paths.plot_path.parent.parent
        self.dim_reduc_parent_dir = dim_reduc_parent_dir

        self.ground_truth = None
        # NOTE(review): these lookups use ``le.get_paths`` rather than
        # ``self.path_func``; presumably both resolve to the same paths for
        # a model -- confirm before unifying them.
        model_paths = le.get_paths(model_names[0])
        ground_truth_file = model_paths.labels_path.joinpath("ground_truth.npy")
        if ground_truth_file.exists():
            # The .npy file holds a pickled mapping (0-d object array), hence
            # allow_pickle and .item().
            ground_truth = np.load(ground_truth_file, allow_pickle=True).item()
            # Offer every distinct key suffix (text after the last ':') as a
            # label option.
            label_names = np.unique(
                [key.split(":")[-1] for key in ground_truth.keys()]
            ).tolist()
            self.ground_truth = True
            self.label_by += label_names
        if any(model_paths.clust_path.glob("*.npy")):
            self.label_by += ["kmeans"]

        self.evaluation_task = evaluation_task
        self.dim_reduction_model = dim_reduction_model
        self.widget_width = 100
        self.vis_loader = EmbedAndLabelLoader(
            dim_reduction_model=dim_reduction_model,
            default_label_keys=default_label_keys,
            **kwargs,
        )
        self.interactive_embedding_plot = True

        # Per-page widget / plot state, keyed by widget_idx.
        self.model_select = dict()
        self.label_select = dict()
        self.noise_select = dict()
        self.autoplay_audio_select = dict()
        self.clfier_select = dict()
        self.species_select = dict()
        self.accumulate_select = dict()
        self.class_select = dict()
        self.embed_plot = dict()
        self.embed_save_button = dict()
        self.embed_notification = dict()
        self.interactive_embed_plot = dict()
        self.spectrogram_plot_panel = dict()
        self.spec_plot_obj = dict()
        self._trigger_spec_obj_update = dict()
        self.class_options = dict()
        self.preds_data = dict()
        self.clfier_path = dict()
        self.clfier_thresh = dict()
        self.btn_run_clfier = dict()
        self.progress_bar = dict()
        self.trigger_classification = dict()
        self.loading_test_placeholder = dict()
        self.heatmap_plot = dict()
        self.kwargs = kwargs

    def _remove_noise_option(self, widget_idx):
        """Return the noise widget of ``widget_idx``, or ``False`` when no
        noise widgets were registered at all (``self.noise_select`` empty)."""
        return self.noise_select[widget_idx] if self.noise_select else False

    @staticmethod
    def _missing_task_hint(task):
        """Markdown pane shown when ``task`` was not part of the evaluation."""
        return pn.pane.Markdown(
            f"No {task} task specified. Please check the config file."
        )

    def _embedding_plot_kwargs(self, widget_idx):
        """Keyword arguments shared by the static and the interactive
        embedding plot of ``widget_idx``."""
        return dict(
            loader=self.vis_loader,
            model_name=self.model_select[widget_idx],
            label_by=self.label_select[widget_idx],
            ground_truth=self.ground_truth,
            dim_reduction_model=self.dim_reduction_model,
            remove_noise=self._remove_noise_option(widget_idx),
            dashboard=True,
            dashboard_idx=widget_idx,
        )

    def embedding_panel(self, widget_idx=0):
        """Return the ``("2D Embedding Plot", pn.Column)`` accordion entry.

        With ``self.interactive_embedding_plot`` set (the default), the plot
        is prepared by ``init_interactive_embed_plot`` and then bound to
        ``update_main_plot`` via ``pn.bind``. Otherwise a static plot is
        created with ``init_plot``. The save button and notification widgets
        of ``widget_idx`` are placed below the plot.
        """
        if self.interactive_embedding_plot:
            self.init_interactive_embed_plot(widget_idx)
            embedding_plot = pn.bind(
                self.update_main_plot,
                "interactive_embed",
                plot_embeddings,
                widget_idx,
                **self._embedding_plot_kwargs(widget_idx),
            )
        else:
            embedding_plot = self.init_plot(
                "embed",
                plot_embeddings,
                widget_idx,
                **self._embedding_plot_kwargs(widget_idx),
            )
        return (
            "2D Embedding Plot",
            pn.Column(
                embedding_plot,
                self.embed_save_button[widget_idx],
                self.embed_notification[widget_idx],
            ),
        )

    def spectrogram_panel(self, widget_idx=0):
        """Return the ``("Spectrogram", pn.Column)`` accordion entry.

        Creates the Plotly spectrogram pane (initially a dummy image) and a
        ``SpectrogramPlot`` controller for ``widget_idx``. It also stores a
        ``pn.bind`` of the controller's ``_update_spec_obj`` on the model and
        autoplay widgets in ``self._trigger_spec_obj_update``; that binding is
        not placed in this layout. Adds buttons to play audio and to save the
        selected points.
        """
        self.spectrogram_plot_panel[widget_idx] = pn.pane.Plotly(
            SpectrogramPlot.dummy_image(title=''),
            sizing_mode='stretch_width',
            height=self.kwargs.get('spectrogram_plot_height'),
        )
        embedding_info_dialogue = pn.widgets.StaticText(
            value="", width=self.kwargs.get('accordion_width') - 80,
        )
        self.spec_plot_obj[widget_idx] = SpectrogramPlot(
            self.audio_dir,
            self.vis_loader,
            self.model_select[widget_idx],
            embedding_info_dialogue,
            **self.kwargs,
        )
        self._trigger_spec_obj_update[widget_idx] = pn.bind(
            self.spec_plot_obj[widget_idx]._update_spec_obj,
            self.model_select[widget_idx],
            self.autoplay_audio_select[widget_idx],
        )

        play_audio_button = pn.widgets.Button(name="Play audio", button_type="primary")
        play_audio_button.on_click(self.spec_plot_obj[widget_idx].play_audio)

        save_selection_dialogue = pn.widgets.StaticText(value="", width=400)
        save_selection_button = pn.widgets.Button(
            name="Save selection to file", button_type="primary"
        )
        save_selection_button.on_click(
            lambda event: self.save_selected_points(
                event, save_selection_dialogue, widget_idx
            )
        )
        # Hidden until a save attempt has something to report.
        save_selection_dialogue.visible = False

        return (
            "Spectrogram",
            pn.Column(
                embedding_info_dialogue,
                self.spectrogram_plot_panel[widget_idx],
                save_selection_dialogue,
                pn.Row(play_audio_button, save_selection_button),
                pn.widgets.StaticText(value="", height=80),
            ),
        )

    def clustering_panel(self, widget_idx):
        """Return the ``("Clustering Results", pn.Column)`` accordion entry.

        The column holds a tooltip plus either the clustering plot or, when
        ``"clustering"`` is not an evaluation task, a markdown hint.
        """
        if "clustering" in self.evaluation_task:
            content = self.plot_widget(
                plot_clusterings,
                path_func=self.path_func,
                model_name=self.model_select[widget_idx],
                label_by=self.label_select[widget_idx],
                no_noise=self._remove_noise_option(widget_idx),
            )
        else:
            content = self._missing_task_hint("clustering")
        return (
            "Clustering Results",
            pn.Column(pn.widgets.TooltipIcon(value=tooltips.clustering), content),
        )

    def probing_panel(self, widget_idx):
        """Return the ``("Probing Performance", pn.Column)`` accordion entry.

        The column holds a tooltip plus either the probing results plot or,
        when ``"probing"`` is not an evaluation task, a markdown hint.
        """
        if "probing" in self.evaluation_task:
            content = self.plot_widget(
                plot_classification_results,
                path_func=self.path_func,
                task_name=self.class_select[widget_idx],
                model_name=self.model_select[widget_idx],
                return_fig=True,
            )
        else:
            content = self._missing_task_hint("probing")
        return (
            "Probing Performance",
            pn.Column(pn.widgets.TooltipIcon(value=tooltips.probing), content),
        )

    def model_page(self, widget_idx, single_model=False):
        """Build the page of one model as ``pn.Row(sidebar, content)``.

        With ``single_model=True`` the embedding plot gets its own accordion
        next to a second accordion with the spectrogram, clustering and
        probing panels. Otherwise all four panels share one accordion. The
        heading is bound to the model selector, so it follows the selection.
        """
        sidebar = self.make_sidebar(widget_idx, model=True)
        # ``str.format`` bound to the selector re-renders the heading
        # whenever a different model is chosen.
        accordion_title = pn.bind(
            "Model Dashboard for {}".format, self.model_select[widget_idx]
        )
        if single_model:
            data_panels = pn.Row(
                pn.Accordion(
                    self.embedding_panel(widget_idx),
                    active=[0],
                    width=self.kwargs.get('accordion_width'),
                ),
                pn.Accordion(
                    self.spectrogram_panel(widget_idx),
                    self.clustering_panel(widget_idx),
                    self.probing_panel(widget_idx),
                    active=[0, 1, 2],
                ),
            )
        else:
            data_panels = pn.Accordion(
                self.embedding_panel(widget_idx),
                self.spectrogram_panel(widget_idx),
                self.clustering_panel(widget_idx),
                self.probing_panel(widget_idx),
                active=[0, 1, 2, 3],
                width=self.kwargs.get('accordion_width'),
            )
        heading = pn.widgets.StaticText(
            value=accordion_title,
            styles={
                'font-size': '1.5em',  # Equivalent to a standard H2
                'font-weight': 'bold',
                'margin-top': '0px',
                'margin-bottom': '15px',
            },
        )
        return pn.Row(sidebar, pn.Column(heading, data_panels))

    def all_models_page(self, widget_idx):
        """Build the cross-model overview page as ``pn.Row(sidebar, content)``.

        Contains an embedding comparison of all models, a clustering overview
        and overview probing metrics. The last two are replaced by markdown
        hints when the corresponding evaluation task was not run.
        """
        sidebar = self.make_sidebar(widget_idx, model=False, all_models=True)
        heading = pn.pane.Markdown("## All Models Dashboard")

        embedding_comparison = self.init_plot(
            "embed",
            plot_comparison,
            widget_idx,
            loader=self.vis_loader,
            plot_path=self.plot_path,
            models=self.models,
            dim_reduction_model=self.dim_reduction_model,
            label_by=self.label_select[widget_idx],
            remove_noise=self._remove_noise_option(widget_idx),
            default_label_keys=self.default_label_keys,
            dashboard=True,
        )

        clustering_tooltip = pn.widgets.TooltipIcon(value=tooltips.clustering)
        if "clustering" in self.evaluation_task:
            clustering_content = self.plot_widget(
                clustering_overview,
                path_func=self.path_func,
                model_list=self.models,
                label_by=self.label_select[widget_idx],
                no_noise=self._remove_noise_option(widget_idx),
                **self.kwargs,
            )
        else:
            clustering_content = self._missing_task_hint("clustering")

        if "probing" in self.evaluation_task:
            probing_content = self.plot_widget(
                plot_overview_metrics,
                plot_path=None,
                metrics=None,
                task_name=self.class_select[widget_idx],
                path_func=self.path_func,
                model_list=self.models,
                return_fig=True,
            )
        else:
            probing_content = self._missing_task_hint("probing")

        main_content = pn.Column(
            heading,
            pn.Accordion(
                ("Embedding Comparison", embedding_comparison),
                (
                    "Clustering Overview",
                    pn.Column(clustering_tooltip, clustering_content),
                ),
                ("Probing Metrics", probing_content),
                active=[0, 1, 2],
            ),
            width=2 * self.kwargs.get('accordion_width'),
        )
        return pn.Row(sidebar, main_content)

    def apply_clfier_page(self, widget_idx):
        """Build a classifier-predictions page as ``pn.Row(sidebar, content)``.

        Creates the classifier path input (hidden by default), a threshold
        input, a button that loads predictions and redraws the heatmap via
        ``update_main_plot``, a progress bar with a status text, and a
        ``PredictionsLoader`` stored in ``self.preds_data[widget_idx]``.
        """
        self.class_options[widget_idx] = []
        sidebar = self.make_sidebar(widget_idx, model=True, classifier_page=True)

        # Path to a linear probe; the placeholder suggests the probe of the
        # first model.
        self.clfier_path[widget_idx] = pn.widgets.TextInput(
            name='Path to Linear Probe',
            placeholder=(
                self.path_func(self.models[0]).probe_path / 'linear_probe.pt'
            ).as_posix(),
            width=600,
            max_length=800,
            visible=False,
        )
        self.clfier_thresh[widget_idx] = pn.widgets.TextInput(
            name='Threshold for classification',
            placeholder='0.5',
            width=80,
        )
        self.btn_run_clfier[widget_idx] = pn.widgets.Button(
            name='Load predictions from integrated classifier',
            width=100,
            height=30,
        )
        self.progress_bar[widget_idx] = pn.indicators.Progress(
            value=0,
            max=100,
            bar_color='primary',
            width=500,
        )
        self.loading_test_placeholder[widget_idx] = pn.widgets.StaticText(
            name='Preparing classification',
            value='',
        )
        self.clfier_select[widget_idx].param.watch(
            lambda event: self.change_input_options(event, widget_idx=widget_idx),
            'value',
        )
        self.preds_data[widget_idx] = PredictionsLoader(
            self.vis_loader,
            self.path_func,
            self.models,
            panel_selection=self.species_select[widget_idx],
            progress_bar=self.progress_bar[widget_idx],
            loading_pane=self.loading_test_placeholder[widget_idx],
        )
        self.btn_run_clfier[widget_idx].on_click(
            lambda event: self.update_main_plot(
                "heatmap",
                plot_classification_heatmap,
                widget_idx=widget_idx,
                event=event,
                predictions_loader=self.preds_data[widget_idx],
                model=self.model_select[widget_idx],
                accumulate_by=self.accumulate_select[widget_idx],
                species=self.species_select[widget_idx],
                threshold=self.clfier_thresh[widget_idx],
                clfier_path=self.clfier_path[widget_idx],
                clfier_type=self.clfier_select[widget_idx],
                **self.kwargs,
            )
        )

        classification_settings = pn.Column(
            self.clfier_path[widget_idx],
            # classes the selected classifier predicts, refreshed on path change
            pn.widgets.StaticText(
                name='Classes',
                value=pn.bind(
                    self.preds_data[widget_idx].get_classes,
                    self.clfier_path[widget_idx],
                ),
            ),
            self.clfier_thresh[widget_idx],
            self.btn_run_clfier[widget_idx],
            # status text shown while embeddings are loading
            self.loading_test_placeholder[widget_idx],
            self.progress_bar[widget_idx],
        )
        heatmap = self.init_plot(
            "heatmap",
            plot_classification_heatmap,
            widget_idx=widget_idx,
            event=None,
            predictions_loader=self.preds_data[widget_idx],
            model=self.model_select[widget_idx],
            accumulate_by=self.accumulate_select[widget_idx],
            species=self.species_select[widget_idx],
            threshold=self.clfier_thresh[widget_idx],
            clfier_type=self.clfier_select[widget_idx],
            **self.kwargs,
        )
        main_content = pn.Column(
            # NOTE(review): heading looks copied from all_models_page; this
            # page shows classifier predictions.
            pn.pane.Markdown("## All Models Dashboard"),
            pn.Accordion(
                ("Classification settings", classification_settings),
                ("Classification heatmap", heatmap),
                active=[0, 1],
            ),
            width=self.kwargs.get('accordion_width'),
        )
        return pn.Row(sidebar, main_content)

    def build_layout(self):
        """
        Build the dashboard tabs and store them as ``self.app``.

        Tabs: "Single model" (widget index 0), "Two models" (indices 1 and
        2), "All models", "Single Model Predictions" (classifier index 1) and
        "Two Model Predictions" (classifier indices 0 and 1). Every page is a
        ``pn.Row(sidebar, content)``. The two-model tabs stack both sidebars
        in one column and show both contents side by side.
        """
        model0_page = self.model_page(0, single_model=True)
        model1_page = self.model_page(1)
        model2_page = self.model_page(2)
        # NOTE(review): reuses widget index 1 of model1_page; confirm the
        # per-index widget dicts are meant to be shared/overwritten here.
        model_all_page = self.all_models_page(1)
        apply_classifier0_page = self.apply_clfier_page(0)
        apply_classifier1_page = self.apply_clfier_page(1)

        sidebar1, content1 = model1_page.objects
        sidebar2, content2 = model2_page.objects
        sidebar3, content3 = apply_classifier0_page.objects
        sidebar4, content4 = apply_classifier1_page.objects

        sidebar1 = pn.Column(pn.pane.Markdown("## Model 2"), sidebar1)

        self.app = pn.Tabs(
            ("Single model", model0_page),
            (
                "Two models",
                pn.Row(
                    pn.Column(sidebar1, sidebar2),
                    pn.Row(content1, content2),
                    sizing_mode="stretch_both",
                ),
            ),
            ("All models", model_all_page),
            ("Single Model Predictions", apply_classifier1_page),
            (
                "Two Model Predictions",
                pn.Row(
                    pn.Column(sidebar3, sidebar4),
                    pn.Row(content3, content4),
                    sizing_mode="stretch_both",
                ),
            ),
        )
        self.add_styling(model0_page, model2_page, model_all_page, apply_classifier1_page)

    def add_styling(self, *pages):
        """Append the bacpipe logo, contact info and a close button to the
        sidebar (``page.objects[0]``) of every given page."""
        # ``importlib.resources.path`` only guarantees that the file exists
        # inside the ``with`` block (it may be a temporary extraction), and
        # the PNG pane reads its source lazily, so load the bytes right here.
        with pkg_resources.path(bacpipe.imgs, 'bacpipe_unlabelled.png') as p:
            logo_png = p.read_bytes()

        def shutdown_callback(event):
            logger.info("Shutting down dashboard server...")
            sys.exit(0)

        for page in pages:
            sidebar = page.objects[0]
            sidebar.append(pn.pane.PNG(logo_png, sizing_mode="scale_width"))
            sidebar.append(pn.Spacer(height=20))
            # Blank lines separate the paragraphs; Markdown panes dedent the
            # common leading whitespace.
            sidebar.append(
                pn.pane.Markdown(
                    """
                    **Contact**

                    If you run into problems, please raise issues on github

                    Please collaborate and help make bacpipe as convenient for many as possible

                    🌍 [github](https://github.com/bioacoustic-ai/bacpipe)

                    To stay updated with new releases, subscribe to the [newsletter](https://buttondown.com/vskode)
                    """
                )
            )
            close_button = pn.widgets.Button(name="❌ close dashboard")
            close_button.on_click(shutdown_callback)
            sidebar.append(close_button)
def visualize_using_dashboard(
    models,
    dashboard_port=5006,
    dashboard_address='localhost',
    dashboard_websocket_origin=False,
    **kwargs
):
    """
    Create and serve the dashboard for visualization. To colorcode embeddings
    by other labels than the default ones, create an annotations file with timestamps.
    An example file can be found in 'bacpipe/tests/test_data/annotations.csv'.
    Multiple dashboards can be opened, the port will simply increment.

    Parameters
    ----------
    models : list
        embedding models
    dashboard_port : int
        First port to try; incremented while the port is already in use.
    dashboard_address : str
        Address the server binds to.
    dashboard_websocket_origin : str, list or bool
        Allowed websocket origin(s) passed to ``template.show``. Any falsy
        value (the default ``False``) configures no explicit origin.
    kwargs : dict
        Dictionary with parameters for dashboard creation

    Raises
    ------
    Exception
        Whatever ``DashBoard.build_layout`` raises; it is logged first.
    OSError
        If serving fails for any reason other than the port being in use.
    """
    import errno

    from bacpipe.embedding_evaluation.visualization.dashboard import DashBoard
    import panel as pn

    # Configure dashboard
    dashboard = DashBoard(models, **kwargs)
    # Build the dashboard layout
    try:
        dashboard.build_layout()
    except Exception as e:
        logger.exception(
            f"\nError building dashboard layout: {e}\n \n "
            "Are you sure all the evaluations have been performed? "
            "If not, rerun the pipeline with `overwrite=True`.\n \n "
        )
        raise

    # NOTE(review): like any ``importlib.resources.path`` result, this path is
    # only guaranteed valid inside the ``with`` block; it stays valid for
    # regular (non-zipped) installs.
    with pkg_resources.path(bacpipe.imgs, 'bacpipe_favicon_white.png') as p:
        favicon_path = str(p)
    template = pn.template.BootstrapTemplate(
        site="bacpipe dashboard",
        title="Explore embeddings of audio data",
        favicon=favicon_path,  # must be a path ending in .ico, .png, etc.
        main=[dashboard.app],
    )

    # Normalise falsy values (e.g. the default ``False``) to "no origin".
    websocket_origin = dashboard_websocket_origin or None

    while True:
        try:
            template.show(
                port=dashboard_port,
                address=dashboard_address,
                websocket_origin=websocket_origin,
            )
            break
        except OSError as err:
            if err.errno != errno.EADDRINUSE:
                # Other failures (unresolvable address, privileged port, ...)
                # are not fixed by trying the next port; retrying would loop
                # forever.
                raise
            logger.warning(
                f"The port {dashboard_port} is already in use. This "
                "is most likely the case because you already have a "
                "dashboard open. There is a exit button in the bottom "
                "left of the dashboard. If this was intentional and you "
                "want to open multiple dashboards at once, ignore this message."
            )
            dashboard_port += 1